Autoupdates specification
RxInfer.@autoupdates (Macro)
@autoupdates [ options... ] begin
    argument_to_update = some_function(q(some_variable_from_the_model))
end
Creates the auto-updates specification for the infer function in online (streaming) Bayesian inference, where prior states must be updated from newly computed posteriors. Read more about the @autoupdates syntax in the official documentation.
RxInfer supports streaming inference on infinite data streams, in which posterior beliefs over latent states are updated automatically as soon as new observations arrive. However, we also want to update our priors based on these updated beliefs. Let's begin with a simple example:
using RxInfer

@model function streaming_beta_bernoulli(a, b, y)
    θ ~ Beta(a, b)
    y ~ Bernoulli(θ)
end
For this model, the RxInfer engine updates the posterior belief over the variable θ every time we receive a new observation y. However, we also wish to update our prior belief by adjusting the arguments a and b as soon as we have a new belief for θ. The @autoupdates macro automates this process: it specifies how selected model arguments are recomputed from new beliefs within the model. Here's how it could look:
autoupdates = @autoupdates begin
    a, b = params(q(θ))
end
@autoupdates begin
    (a, b) = params(q(θ))
end
This specification directs the RxInfer inference engine to update a and b by calling the params function on the posterior q of θ. The params function, defined in the Distributions.jl package, returns the parameters of a distribution as a tuple; here it returns the two shape parameters of the Beta posterior, which are assigned to a and b respectively.
Now we can pass the autoupdates structure to the infer function as follows:
# Streaming inference supports static datasets as well
data = (y = [ 1, 0, 1 ], )

result = infer(
    model = streaming_beta_bernoulli(),
    autoupdates = autoupdates,
    data = data,
    keephistory = 3,
    initialization = @initialization(q(θ) = Beta(1, 1))
)

result.history[:θ]
3-element DataStructures.CircularBuffer{Any}:
 Beta{Float64}(α=2.0, β=1.0)
 Beta{Float64}(α=2.0, β=2.0)
 Beta{Float64}(α=3.0, β=2.0)
In this example, we also used the initialization keyword argument. It is required for every latent state that appears in the @autoupdates specification when running streaming inference: here, q(θ) = Beta(1, 1) supplies the posterior from which the first values of a and b are computed.
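The history above follows from Beta-Bernoulli conjugacy: observing y turns a Beta(a, b) prior into a Beta(a + y, b + 1 - y) posterior, and the autoupdate then feeds those parameters back in as the next prior. The plain-Julia sketch below (no RxInfer required; the helper name is hypothetical) reproduces the three posteriors by hand, starting from the Beta(1, 1) initialization.

```julia
# Plain-Julia sketch of the Beta-Bernoulli recursion behind result.history[:θ].
# `beta_bernoulli_history` is a hypothetical helper, not part of RxInfer.
function beta_bernoulli_history(ys; a = 1.0, b = 1.0)
    # (a, b) start from the initialization q(θ) = Beta(1, 1)
    history = Tuple{Float64, Float64}[]
    for y in ys
        # Conjugate update: Beta(a, b) prior and observation y give Beta(a + y, b + 1 - y)
        a, b = a + y, b + (1 - y)
        # The autoupdate `a, b = params(q(θ))` makes this posterior the next prior
        push!(history, (a, b))
    end
    return history
end

beta_bernoulli_history([1, 0, 1])  # [(2.0, 1.0), (2.0, 2.0), (3.0, 2.0)]
```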
Consider another example with the following model and auto-update specification:
@model function kalman_filter(y, x_current_mean, x_current_var)
    x_current ~ Normal(mean = x_current_mean, var = x_current_var)
    x_next ~ Normal(mean = x_current, var = 1.0)
    y ~ Normal(mean = x_next, var = 1.0)
end
This model has two arguments that represent our prior knowledge of the x_current state of the system. The latent state x_next represents the subsequent state of the system and is linked to the observed variable y. An auto-update specification could look like the following:
autoupdates = @autoupdates begin
    x_current_mean = mean(q(x_next))
    x_current_var = var(q(x_next))
end
@autoupdates begin
    x_current_mean = mean(q(x_next))
    x_current_var = var(q(x_next))
end
This specification updates our prior as soon as a new posterior q(x_next) is available: it applies the mean and var functions to the updated posterior and assigns the results to x_current_mean and x_current_var.
result = infer(
    model = kalman_filter(),
    data = (y = rand(3), ),
    autoupdates = autoupdates,
    initialization = @initialization(q(x_next) = NormalMeanVariance(0, 1)),
    keephistory = 3,
)

result.history[:x_next]
3-element DataStructures.CircularBuffer{Any}:
 NormalWeightedMeanPrecision{Float64}(xi=0.7016294806240742, w=1.5)
 NormalWeightedMeanPrecision{Float64}(xi=1.1694120758871198, w=1.6)
 NormalWeightedMeanPrecision{Float64}(xi=0.973904078503759, w=1.6153846153846154)
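The precisions in this history can be checked by hand. Under the model, the prior of x_next has variance x_current_var + 1, observing y with unit noise variance adds 1 to the precision, and the weighted mean is xi = x_current_mean / (x_current_var + 1) + y. The plain-Julia sketch below (hypothetical helper name) assumes, as the first entry suggests, that the first step reads its prior from the initialization NormalMeanVariance(0, 1). The precision sequence does not depend on the data, so it reproduces w = 1.5, 1.6, 1.6153846...; the xi values depend on rand(3) and differ from run to run.

```julia
# Plain-Julia sketch of the scalar recursion behind result.history[:x_next].
# `kalman_history` is a hypothetical helper, not part of RxInfer.
function kalman_history(ys; m = 0.0, v = 1.0)
    # (m, v): mean and variance of q(x_next), initially NormalMeanVariance(0, 1)
    history = Tuple{Float64, Float64}[]
    for y in ys
        # autoupdates: x_current_mean = m, x_current_var = v
        prior_var = v + 1.0         # x_next ~ Normal(mean = x_current, var = 1.0)
        w  = 1 / prior_var + 1.0    # add the precision of y ~ Normal(mean = x_next, var = 1.0)
        xi = m / prior_var + y      # weighted mean (precision times mean)
        push!(history, (xi, w))
        m, v = xi / w, 1 / w        # mean(q(x_next)) and var(q(x_next)) for the next step
    end
    return history
end

[w for (_, w) in kalman_history(rand(3))]  # ≈ [1.5, 1.6, 1.6153846153846154]
```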
Read more about streaming inference in the Streaming (online) inference section.
General syntax
The @autoupdates macro accepts either a block of code or a full function definition. It detects and transforms lines structured as follows:
(model_arguments...) = some_function(model_variables...)
These lines are referred to as individual autoupdate specifications. Other expressions remain unchanged. The macro returns an RxInfer.AutoUpdateSpecification structure that holds a collection of RxInfer.IndividualAutoUpdateSpecification objects.
The @autoupdates macro identifies an individual autoupdate specification if model_variables... contains:
- q(s), which monitors updates from the marginal posteriors of an individual variable s or of a collection of variables s;
- q(s[i]), which monitors updates from the marginal posterior of the collection of variables s at index i.
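The detection rule above can be sketched in plain Julia by walking the quoted expression tree. This is a hypothetical, simplified illustration (the function names are made up, and it is not RxInfer's parser): it treats an assignment as an individual autoupdate whenever its right-hand side contains at least one q(s) or q(s[i]) call, and it does not reproduce the macro's additional checks, such as rejecting q(θ)[i].

```julia
# Hypothetical illustration, not RxInfer's implementation.

# Is `x` a call of the form q(s) or q(s[i])?
function is_q_call(x)
    if !(x isa Expr && x.head == :call && length(x.args) == 2 && x.args[1] == :q)
        return false
    end
    arg = x.args[2]
    return arg isa Symbol || (arg isa Expr && arg.head == :ref)
end

# Recursively collect every q(...) call inside an expression.
collect_q_calls!(found, x) = found  # symbols, literals, line numbers: nothing to collect
function collect_q_calls!(found, x::Expr)
    is_q_call(x) && push!(found, x)
    foreach(child -> collect_q_calls!(found, child), x.args)
    return found
end
collect_q_calls(ex) = collect_q_calls!(Any[], ex)

# An assignment whose right-hand side watches at least one q(...) call.
function is_individual_autoupdate(ex)
    if !(ex isa Expr && ex.head == :(=))
        return false
    end
    return !isempty(collect_q_calls(ex.args[2]))
end

is_individual_autoupdate(:((a, b) = params(q(θ))))  # true
is_individual_autoupdate(:(a = f(1)))               # false
```

Because the walk also descends into dot-calls, a broadcast line such as x = clamp.(mean.(q(z)), 0, 1) is detected the same way.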
Expressions that do not meet the above criteria remain unmodified. For instance, an expression like a = f(1) is not considered an individual autoupdate. As a result, the @autoupdates macro can contain arbitrary expressions, including definitions of temporary variables or even functions. Additionally, an individual autoupdate specification may use any intermediate constants, such as a, b = some_function(q(s), a_constant).
The model_arguments... can be either a single model argument or a tuple of model arguments, as defined within the @model macro. However, if model_arguments... is a tuple, for example in a, b = some_function(q(s)), then some_function must also return a tuple of the same length (length 2 in this example).
Individual autoupdate specifications can also involve more complex expressions, as demonstrated below:
@autoupdates begin
    a = mean(q(θ)) / 2
    b = 2 * (mean(q(θ)) + 1)
end
@autoupdates begin
    a = /(mean(q(θ)), 2)
    b = *(2, +(mean(q(θ)), 1))
end
or
@autoupdates begin
    x = clamp(mean(q(z)), 0, 1)
end
@autoupdates begin
    x = clamp(mean(q(z)), 0, 1)
end
The q(θ)[i] and f(q(θ))[i] syntax is not supported; use getindex(q(θ), i) or getindex(f(q(θ)), i) instead.
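The prefix form echoed in the outputs above mirrors how Julia itself represents these expressions: arithmetic operators are ordinary :call nodes. Indexing, by contrast, produces a separate :ref node that wraps the q(θ) call, which is the syntactic difference between q(θ)[i] and getindex(q(θ), i). The snippet only inspects quoted expressions and needs no packages:

```julia
# Quoted expressions are plain data, so their structure can be inspected directly.
ex_div = :(mean(q(θ)) / 2)
ex_div.head      # :call
ex_div.args[1]   # :/ (the operator is the called function)

ex_ref = :(q(θ)[i])
ex_ref.head      # :ref (indexing wraps the q(θ) call in its own node)

ex_get = :(getindex(q(θ), i))
ex_get.head      # :call
ex_get.args[1]   # :getindex
```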
The @autoupdates macro also supports broadcasting:
@autoupdates begin
    x = clamp.(mean.(q(z)), 0, 1)
end
@autoupdates begin
    x = clamp.(mean.(q(z)), 0, 1)
end
Read more about broadcasting in the official Julia documentation.
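In plain Julia, the dot syntax applies a function elementwise, so clamp.(mean.(q(z)), 0, 1) computes one mean per element of the collection z and clamps each into [0, 1]. The sketch below uses vectors of samples as a hypothetical stand-in for the marginals q(z), with Statistics.mean in place of the distribution mean:

```julia
using Statistics

# Hypothetical stand-in for q(z): one vector of samples per element of z
qz = [[-0.5, 0.0], [0.25, 0.75], [1.0, 2.0]]

# Same shape as `x = clamp.(mean.(q(z)), 0, 1)`: one mean per element, each clamped to [0, 1]
x = clamp.(mean.(qz), 0, 1)  # [0.0, 0.5, 1.0]
```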
An individual autoupdate can also depend on multiple latent states simultaneously, e.g.:
@autoupdates begin
    a = f(q(μ), q(s), q(τ))
    b = g(q(θ))
end
@autoupdates begin
    a = f(q(μ), q(s), q(τ))
    b = g(q(θ))
end
As mentioned before, the @autoupdates macro also accepts a full function definition, which can take arbitrary arguments:
@autoupdates function generate_autoupdates(f, condition)
    if condition
        a = f(q(θ))
    else
        a = f(q(s))
    end
end

autoupdates = generate_autoupdates(mean, true)
@autoupdates begin
    a = mean(q(θ))
end
The options block
Optionally, the @autoupdates macro accepts a set of [ options... ] before the main block or the full function definition. The available options are:
- warn = true/false: enables or disables warnings when the specification is used with an incompatible model. Set to true by default.
- strict = true/false: turns warnings into errors. Set to false by default.
autoupdates = @autoupdates [ strict = true ] begin
    a, b = params(q(θ))
end
@autoupdates begin
    (a, b) = params(q(θ))
end
or
@autoupdates [ strict = true ] function generate_autoupdates()
    a, b = params(q(θ))
end

autoupdates = generate_autoupdates()
@autoupdates begin
    (a, b) = params(q(θ))
end
Internal data structures
RxInfer.AutoUpdateSpecification (Type)
AutoUpdateSpecification(specifications)
A structure that holds a collection of individual auto-update specifications. Each specification defines how to update the model's arguments based on new posterior or message updates.
RxInfer.parse_autoupdates (Function)
parse_autoupdates(options, expression)
Parses the internals of the expression passed to the @autoupdates macro and returns the RxInfer.AutoUpdateSpecification structure.
RxInfer.autoupdate_check_reserved_expressions (Function)
autoupdate_check_reserved_expressions(block)
Checks whether the expression is a valid autoupdate specification, since some expressions are forbidden within an autoupdate specification.
RxInfer.numautoupdates (Function)
Returns the number of auto-updates in the specification.
Base.isempty (Method)
Returns true if the auto-update specification is empty.
RxInfer.getautoupdate (Function)
Returns the individual auto-update specification at the given index.
RxInfer.addspecification (Function)
Appends the individual auto-update specification to the existing specification.
RxInfer.getvarlabels (Function)
Returns the labels of the auto-update specification, which are the names of the variables to update.
RxInfer.IndividualAutoUpdateSpecification (Type)
IndividualAutoUpdateSpecification(varlabels, arguments, mapping)
A structure that defines how to update a single variable in the model. It consists of the variable labels and the mapping function.
RxInfer.getmapping (Function)
Returns the mapping function of the auto-update specification, which defines how to update the variable.
RxInfer.AutoUpdateVariableLabel (Type)
AutoUpdateVariableLabel{L, I}(label, [ index = nothing ])
A structure that holds the label of the variable to update and its index. By default, the index is set to nothing.
RxInfer.AutoUpdateMapping (Type)
AutoUpdateMapping(arguments, mappingFn)
A structure that holds the arguments and the mapping function for the individual auto-update specification.
RxInfer.AutoUpdateFetchMarginalArgument (Type)
An autoupdate argument that fetches updates from the marginal of a variable.
RxInfer.AutoUpdateFetchMessageArgument (Type)
An autoupdate argument that fetches updates from the last message (in the array of messages) of a variable.
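To make the relationships between these structures concrete, here is a heavily simplified, hypothetical mirror in plain Julia. All names (ToySpec, ToyIndividualSpec, ToyVariableLabel, num_updates, get_update, add_update, apply_update) are made up for illustration, the fields only loosely follow the documented constructors, and the marginal is a NamedTuple standing in for a Beta distribution; RxInfer's real types fetch values from streams instead.

```julia
# Hypothetical, heavily simplified mirror of the structures documented above.
# Names and fields are illustrative only; they are not RxInfer's definitions.

struct ToyVariableLabel
    label::Symbol
    index::Union{Nothing, Int}            # `nothing` unless the argument is indexed
end
ToyVariableLabel(label::Symbol) = ToyVariableLabel(label, nothing)

struct ToyIndividualSpec
    varlabels::Vector{ToyVariableLabel}   # model arguments to update, e.g. a and b
    arguments::Vector{Symbol}             # latent variables whose marginals are watched
    mapping::Function                     # maps the fetched marginals to new argument values
end

struct ToySpec
    specifications::Vector{ToyIndividualSpec}
end

num_updates(spec::ToySpec) = length(spec.specifications)
Base.isempty(spec::ToySpec) = isempty(spec.specifications)
get_update(spec::ToySpec, i::Int) = spec.specifications[i]
add_update(spec::ToySpec, s::ToyIndividualSpec) = ToySpec(push!(copy(spec.specifications), s))

# Apply one individual specification to the current marginals.
function apply_update(s::ToyIndividualSpec, marginals::Dict{Symbol})
    out = s.mapping((marginals[name] for name in s.arguments)...)
    out = out isa Tuple ? out : (out,)    # a single argument gets a 1-tuple
    return Dict(l.label => v for (l, v) in zip(s.varlabels, out))
end

# Roughly the shape of `a, b = params(q(θ))`
ab = ToyIndividualSpec([ToyVariableLabel(:a), ToyVariableLabel(:b)], [:θ], d -> (d.α, d.β))
spec = add_update(ToySpec(ToyIndividualSpec[]), ab)
apply_update(get_update(spec, 1), Dict(:θ => (α = 2.0, β = 1.0)))  # Dict(:a => 2.0, :b => 1.0)
```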
RxInfer.prepare_autoupdates_for_model (Function)
prepare_autoupdates_for_model(autoupdates, model)
Extracts the variables referenced in autoupdates from the model and replaces AutoUpdateFetchMarginalArgument and AutoUpdateFetchMessageArgument with the actual streams.