Autoupdates specification

RxInfer.@autoupdates (Macro)
@autoupdates [ options... ] begin 
    argument_to_update = some_function(q(some_variable_from_the_model))
end

Creates the auto-update specification for the infer function in online (streaming) Bayesian inference, where prior states must be updated from the newly computed posteriors. Read more about the @autoupdates syntax in the official documentation.

RxInfer supports streaming inference on infinite data streams, where posterior beliefs over latent states are updated automatically as soon as new observations arrive. Often, however, we also want to update our priors based on these updated beliefs. Let's begin with a simple example:

using RxInfer

@model function streaming_beta_bernoulli(a, b, y)
    θ ~ Beta(a, b)
    y ~ Bernoulli(θ)
end

For this model, the RxInfer engine will update the posterior belief over the variable θ every time we receive a new observation y. However, we also wish to update our prior belief by adjusting the arguments a and b as soon as we have a new belief for the variable θ. The @autoupdates macro automates this process, simplifying the task of writing automatic updates for certain model arguments based on new beliefs within the model. Here's how it could look:

autoupdates = @autoupdates begin
    a, b = params(q(θ))
end
@autoupdates begin
    (a, b) = params(q(θ))
end

This specification directs the RxInfer inference engine to update a and b by calling the params function on the posterior q(θ). The params function, defined in the Distributions.jl package, returns the parameters of a distribution as a tuple; for the Beta posterior here, its two elements are assigned to a and b.

Now we can pass the autoupdates structure to the infer function as follows:

# The streaming inference supports static datasets as well
data = (y = [ 1, 0, 1 ], )

result = infer(
    model          = streaming_beta_bernoulli(),
    autoupdates    = autoupdates,
    data           = data,
    keephistory    = 3,
    initialization = @initialization(q(θ) = Beta(1, 1))
)

result.history[:θ]
3-element DataStructures.CircularBuffer{Any}:
 Beta{Float64}(α=2.0, β=1.0)
 Beta{Float64}(α=2.0, β=2.0)
 Beta{Float64}(α=3.0, β=2.0)

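In this example, we also used the initialization keyword argument. It is required for latent states that appear in the @autoupdates specification during streaming inference, because their initial posteriors are needed to compute the first values of the updated arguments (here, q(θ) = Beta(1, 1) yields the first prior a = 1, b = 1).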
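To see where the numbers in result.history[:θ] come from, the sketch below replays the same loop by hand with Distributions.jl. It assumes the standard Beta-Bernoulli conjugate update, Beta(a, b) and y ∈ {0, 1} give Beta(a + y, b + 1 - y); the function replay_beta_bernoulli is a hypothetical helper for illustration, not part of RxInfer, which computes posteriors by message passing rather than with this code.

```julia
using Distributions

# Hypothetical helper: replays the streaming loop by hand for a vector of observations
function replay_beta_bernoulli(ys; initial = Beta(1.0, 1.0))
    posterior = initial              # plays the role of @initialization(q(θ) = Beta(1, 1))
    history = Beta{Float64}[]
    for y in ys
        a, b = params(posterior)     # autoupdate step: a, b = params(q(θ))
        # Assumed conjugate Beta-Bernoulli update after observing y
        posterior = Beta(a + y, b + 1 - y)
        push!(history, posterior)
    end
    return history
end

history = replay_beta_bernoulli([1, 0, 1])
println(history)
```

Each posterior becomes the next prior, which is exactly the role the @autoupdates specification plays inside infer.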

Consider another example with the following model and auto-update specification:

@model function kalman_filter(y, x_current_mean, x_current_var)
    x_current ~ Normal(mean = x_current_mean, var = x_current_var)
    x_next    ~ Normal(mean = x_current, var = 1.0)
    y         ~ Normal(mean = x_next, var = 1.0)
end

This model has two arguments, x_current_mean and x_current_var, which represent our prior knowledge of the current state x_current of the system. The latent state x_next represents the next state of the system and is linked to the observed variable y. An auto-update specification could look like this:

autoupdates = @autoupdates begin
    x_current_mean = mean(q(x_next))
    x_current_var  = var(q(x_next))
end
@autoupdates begin
    x_current_mean = mean(q(x_next))
    x_current_var = var(q(x_next))
end

This specification tells RxInfer to update the prior as soon as a new posterior q(x_next) is available: it applies the mean and var functions to that posterior and assigns the results to x_current_mean and x_current_var.

result = infer(
    model = kalman_filter(),
    data  = (y = rand(3), ),
    autoupdates = autoupdates,
    initialization = @initialization(q(x_next) = NormalMeanVariance(0, 1)),
    keephistory = 3,
)
result.history[:x_next]
3-element DataStructures.CircularBuffer{Any}:
 NormalWeightedMeanPrecision{Float64}(xi=0.7016294806240742, w=1.5)
 NormalWeightedMeanPrecision{Float64}(xi=1.1694120758871198, w=1.6)
 NormalWeightedMeanPrecision{Float64}(xi=0.973904078503759, w=1.6153846153846154)
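The precisions w in the output above do not depend on the random data, so they can be checked by hand. The sketch below is a hypothetical plain-Julia filter (filter_precisions is an illustrative name, not an RxInfer API) that performs the same predict, correct and autoupdate cycle, assuming the unit noise variances of kalman_filter and the initial posterior NormalMeanVariance(0, 1). The xi values depend on rand(3) and differ between runs.

```julia
# Hypothetical hand-rolled version of the kalman_filter model (not an RxInfer API):
#   x_current ~ N(m, v), x_next ~ N(x_current, 1), y ~ N(x_next, 1)
# Returns the posterior precision w of q(x_next) after each observation.
function filter_precisions(ys; m = 0.0, v = 1.0)   # m, v mirror q(x_next) = N(0, 1)
    precisions = Float64[]
    for y in ys
        prior_var = v + 1.0              # predict: x_next ~ N(m, v + 1)
        w  = 1.0 / prior_var + 1.0       # correct: add the unit observation precision
        xi = m / prior_var + y           # weighted mean, xi = w * posterior mean
        push!(precisions, w)
        m, v = xi / w, 1.0 / w           # autoupdate: mean(q(x_next)), var(q(x_next))
    end
    return precisions
end

precisions = filter_precisions(rand(3))
println(precisions)
```

The sequence of precisions matches the w values reported in result.history[:x_next].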

Read more about streaming inference in the Streaming (online) inference section.

General syntax

The @autoupdates macro accepts either a block of code or a full function definition. It detects and transforms lines structured as follows:

(model_arguments...) = some_function(model_variables...)

These lines are called individual autoupdate specifications; all other expressions remain unchanged. The macro returns an RxInfer.AutoUpdateSpecification structure that holds a collection of RxInfer.IndividualAutoUpdateSpecification structures.

The @autoupdates macro identifies an individual autoupdate specification if the model_variables... contains:

  • q(s), which monitors updates from marginal posteriors of an individual variable s or a collection of variables s.
  • q(s[i]), which monitors updates from marginal posteriors of the collection of variables s at index i.

Expressions not meeting the above criteria remain unmodified. For instance, an expression like a = f(1) is not considered an individual autoupdate. Therefore, the @autoupdates macro can contain arbitrary expressions and allows for the definition of temporary variables or even functions. Additionally, within an individual autoupdate specification, it is possible to use any intermediate constants, such as a, b = some_function(q(s), a_constant).

The model_arguments... can either be a single model argument or a tuple of model arguments, as defined within the @model macro. However, it's important to note that if model_arguments... is a tuple, for example in a, b = some_function(q(s)), then some_function must also return a tuple of the same length (of length 2 in this example).
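As an illustration of the tuple-length rule, the sketch below defines a hypothetical mapping shrink_params (not part of RxInfer or Distributions.jl) that returns a 2-tuple and can therefore fill a, b in a single specification. It is evaluated outside RxInfer on a hand-picked Beta distribution standing in for q(θ); the commented lines show how it would be used inside @autoupdates.

```julia
using Distributions

# Hypothetical mapping: pull each Beta parameter of the posterior towards 1
# before reusing it as a prior. `params` gives a 2-tuple, and `map` over a
# tuple returns a tuple, so the result has length 2 and matches `a, b`.
shrink_params(q; factor = 0.9) = map(p -> 1 + factor * (p - 1), params(q))

posterior = Beta(3.0, 2.0)        # stand-in for q(θ)
a, b = shrink_params(posterior)   # same length on both sides
println((a, b))

# Inside a specification (requires RxInfer and a model with arguments a, b):
# autoupdates = @autoupdates begin
#     a, b = shrink_params(q(θ))
# end
```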

Individual autoupdate specifications can involve somewhat complex expressions, as demonstrated below:

@autoupdates begin
    a = mean(q(θ)) / 2
    b = 2 * (mean(q(θ)) + 1)
end
@autoupdates begin
    a = /(mean(q(θ)), 2)
    b = *(2, +(mean(q(θ)), 1))
end

or

@autoupdates begin
    x = clamp(mean(q(z)), 0, 1)
end
Warning

The q(θ)[i] and f(q(θ))[i] syntax is not supported; use getindex(q(θ), i) or getindex(f(q(θ)), i) instead.
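The sketch below evaluates, outside RxInfer, the mapping expressions from the examples above together with the getindex form recommended in the warning, using a hand-picked Beta distribution as a stand-in for q(θ). The restriction is specific to the macro; in plain Julia both indexing forms are the same call.

```julia
using Distributions

posterior = Beta(3.0, 2.0)   # stand-in for q(θ); its mean is 0.6

# In plain Julia, x[i] is only syntax for getindex(x, i)
@assert params(posterior)[1] == getindex(params(posterior), 1)

# What the mappings shown above would compute for this posterior
a = mean(posterior) / 2
b = 2 * (mean(posterior) + 1)
α = getindex(params(posterior), 1)
println((a, b, α))

# Inside @autoupdates, only the explicit form is accepted (requires RxInfer):
# autoupdates = @autoupdates begin
#     a = getindex(params(q(θ)), 1)
# end
```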

The @autoupdates macro also supports broadcasting:

@autoupdates begin
    x = clamp.(mean.(q(z)), 0, 1)
end

Read more about broadcasting in the official Julia documentation.
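When z is a collection of latent variables, q(z) refers to their marginal posteriors, so mean. and clamp. act element by element. The sketch below evaluates the same broadcasted expression outside RxInfer, on a hand-made vector of Normal distributions that stands in for q(z) (an assumption for illustration only).

```julia
using Distributions

# Stand-in for q(z) when z is a collection: one marginal posterior per element
posteriors = [Normal(-0.5, 1.0), Normal(0.3, 1.0), Normal(1.7, 1.0)]

# The same broadcasted mapping as in the specification above:
# take each posterior mean, then clamp each mean to the interval [0, 1]
x = clamp.(mean.(posteriors), 0, 1)
println(x)
```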

An individual autoupdate can also depend on multiple latent states simultaneously, for example:

@autoupdates begin
    a = f(q(μ), q(s), q(τ))
    b = g(q(θ))
end
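In the specification above, f and g are placeholders. The sketch below gives one hypothetical choice of f that combines three posteriors into a single value and evaluates it outside RxInfer on hand-picked distributions standing in for q(μ), q(s) and q(τ). It assumes the mapping receives the posteriors as distribution objects, as the mean and params examples earlier in this section suggest.

```julia
using Distributions

# Hypothetical mapping combining three marginal posteriors into one value
f(qμ, qs, qτ) = mean(qμ) + mean(qs) / mean(qτ)

# Hand-picked stand-ins for q(μ), q(s) and q(τ)
qμ = Normal(1.0, 0.5)
qs = Normal(2.0, 1.0)
qτ = Gamma(2.0, 2.0)   # shape 2, scale 2, so its mean is 4.0

a = f(qμ, qs, qτ)
println(a)
```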

As mentioned before, @autoupdates also accepts a full function definition, which can take arbitrary arguments:

@autoupdates function generate_autoupdates(f, condition)
    if condition
        a = f(q(θ))
    else
        a = f(q(s))
    end
end

autoupdates = generate_autoupdates(mean, true)
@autoupdates begin
    a = mean(q(θ))
end

The options block

Optionally, the @autoupdates macro accepts a set of [ options... ] before the main block or the full function definition. The available options are:

  • warn = true/false: Enables or disables warnings when the specification is incompatible with the model. Set to true by default.
  • strict = true/false: Turns warnings into errors. Set to false by default.

autoupdates = @autoupdates [ strict = true ] begin
    a, b = params(q(θ))
end
@autoupdates begin
    (a, b) = params(q(θ))
end

or

@autoupdates [ strict = true ] function generate_autoupdates()
    a, b = params(q(θ))
end
autoupdates = generate_autoupdates()
@autoupdates begin
    (a, b) = params(q(θ))
end

Internal data structures

RxInfer.AutoUpdateSpecification (Type)
AutoUpdateSpecification(specifications)

A structure that holds a collection of individual auto-update specifications. Each specification defines how to update the model's arguments based on new posterior or message updates.

RxInfer.parse_autoupdates (Function)
parse_autoupdates(options, expression)

Parses the internals of the expression passed to the @autoupdates macro and returns the RxInfer.AutoUpdateSpecification structure.

RxInfer.getvarlabels (Function)

Returns the labels of the auto-update specification, which are the names of the variables to update.

RxInfer.IndividualAutoUpdateSpecification (Type)
IndividualAutoUpdateSpecification(varlabels, arguments, mapping)

A structure that defines how to update a single variable in the model. It consists of the variable labels and the mapping function.
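To make the roles of varlabels, arguments and mapping concrete, here is a toy analogue written from scratch. ToyAutoUpdate and apply_toy are hypothetical names, not RxInfer internals: the real structure stores its labels as AutoUpdateVariableLabel values and, after prepare_autoupdates_for_model, works on streams rather than on a Dict of posteriors.

```julia
using Distributions

# Hypothetical toy analogue (not RxInfer's implementation) of an individual
# auto-update specification: labels to update, variables to read, and a mapping.
struct ToyAutoUpdate{F}
    varlabels::Vector{Symbol}   # model arguments to update, e.g. [:a, :b]
    arguments::Vector{Symbol}   # latent variables whose posteriors are read, e.g. [:θ]
    mapping::F                  # maps those posteriors to the new argument values
end

# Apply the mapping to the current posteriors and pair each result with its label
function apply_toy(spec::ToyAutoUpdate, posteriors::Dict{Symbol})
    result = spec.mapping(map(name -> posteriors[name], spec.arguments)...)
    # A single label receives the whole result; several labels unpack a tuple
    result = length(spec.varlabels) == 1 ? (result,) : result
    return Dict(zip(spec.varlabels, result))
end

# Toy version of `a, b = params(q(θ))`
spec = ToyAutoUpdate([:a, :b], [:θ], params)
updated = apply_toy(spec, Dict(:θ => Beta(3.0, 2.0)))
println(updated)
```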

RxInfer.getmapping (Function)

Returns the mapping function of the auto-update specification, which defines how to update the variable.

RxInfer.AutoUpdateVariableLabel (Type)
AutoUpdateVariableLabel{L, I}(label, [ index = nothing ])

A structure that holds the label of the variable to update and its index. By default, the index is set to nothing.

RxInfer.AutoUpdateMapping (Type)
AutoUpdateMapping(arguments, mappingFn)

A structure that holds the arguments and the mapping function for the individual auto-update specification.

RxInfer.prepare_autoupdates_for_model (Function)
prepare_autoupdates_for_model(autoupdates, model)

This function extracts the variables saved in autoupdates from the model and replaces AutoUpdateFetchMarginalArgument and AutoUpdateFetchMessageArgument with the actual streams.
