Migration Guide from version 2.x to 3.x

This guide is intended to help you migrate your project from version 2.x to 3.x of RxInfer. The main difference between these two versions is the redefinition of the model specification language. A detailed explanation of the new model definition language can be found in the GraphPPL documentation. Here, we give an overview of the most important changes and introduce RxInfer-specific changes.

Model specification

Also read the Model specification guide.

randomvar, datavar and constvar have been removed

The most notable change in the model specification is the removal of the randomvar, datavar, and constvar functions. The @model macro now automatically determines whether a variable is random or constant based on its usage. Variables that were previously declared with datavar must now be listed in the argument list of the model.

The following example is a simple model definition in version 2:

@model function SSM(n, x0, A, B, Q, P) 
    x = randomvar(n) 
    y = datavar(Vector{Float64}, n) 
    x_prior ~ MvNormal(μ = mean(x0), Σ = cov(x0)) 
    x_prev = x_prior 
    for i in 1:n 
        x[i] ~ MvNormal(μ = A * x_prev, Σ = Q) 
        y[i] ~ MvNormal(μ = B * x[i], Σ = P) 
        x_prev = x[i] 
    end 
end 

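The equivalent model definition in version 3 is as follows. The observations y are now a model argument, the prior is passed directly as a distribution (prior_x) instead of being rebuilt from mean(x0) and cov(x0), and the number of time steps is derived from the data with eachindex(y):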

@model function SSM(y, prior_x, A, B, Q, P) 
    x_prev ~ prior_x
    for i in eachindex(y)
        x[i] ~ MvNormal(μ = A * x_prev, Σ = Q) 
        y[i] ~ MvNormal(μ = B * x[i], Σ = P) 
        x_prev = x[i]
    end
end

Read more about the change in the GraphPPL documentation and in the updated Model specification guide.
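As a rough sketch of how the version 3 SSM model might be used, the observations go through the data keyword of infer, while every other argument becomes a keyword argument of the model. The identity dynamics, unit noise covariances, prior and observation values below are made-up assumptions for illustration only.

```julia
using RxInfer, LinearAlgebra

@model function SSM(y, prior_x, A, B, Q, P)
    x_prev ~ prior_x
    for i in eachindex(y)
        x[i] ~ MvNormal(μ = A * x_prev, Σ = Q)
        y[i] ~ MvNormal(μ = B * x[i], Σ = P)
        x_prev = x[i]
    end
end

# Illustrative (made-up) system: identity dynamics, observation and noise matrices
Id = Matrix{Float64}(I, 2, 2)
A, B = Id, Id
Q, P = Id, Id

# Illustrative (made-up) 2-D observations, one vector per time step
observations = [[1.0, 0.5], [1.2, 0.7], [1.1, 0.9]]

result = infer(
    # Everything that is not data is given as a keyword argument of the model
    model = SSM(prior_x = MvNormalMeanCovariance(zeros(2), Id), A = A, B = B, Q = Q, P = P),
    # The observations are passed separately through the `data` keyword
    data  = (y = observations,)
)

# One posterior per time step for the latent states x
println(length(result.posteriors[:x]))
```

Note that the model itself never states how many time steps there are; the loop length follows from the length of the data passed to infer.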

Positional arguments are converted to keyword arguments

The changes in the model specification also affect the infer function. Since all interfaces to a model are now passed as arguments of the function defined with the @model macro, the infer function needs additional information about how to construct the model. Therefore, the model function converts all positional arguments to keyword arguments, and positional arguments are no longer supported when calling a model function. Below is an example of the new model definition:

using RxInfer

@model function coin_toss(prior, y)
    θ ~ prior
    y .~ Bernoulli(θ)
end

# Here, the prior is passed as a parameter of the model, and `y` is passed as data.
# Both are given by name: `prior` as a keyword argument of `coin_toss`, `y` through the `data` keyword of `infer`.
infer(
    model = coin_toss(prior = Beta(1, 1)),
    data  = (y = [1, 0, 1],)
)
Inference results:
  Posteriors       | available for (θ)
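The posterior can then be read from the posteriors field of the returned inference result, which is keyed by variable name. A minimal sketch reusing the coin_toss model above; with a Beta(1, 1) prior and the observations [1, 0, 1], conjugacy gives a Beta(3, 2) posterior, whose mean is 0.6.

```julia
using RxInfer

@model function coin_toss(prior, y)
    θ ~ prior
    y .~ Bernoulli(θ)
end

result = infer(
    model = coin_toss(prior = Beta(1, 1)),
    data  = (y = [1, 0, 1],)
)

# Posteriors are stored in a dictionary keyed by the variable name
qθ = result.posteriors[:θ]
println(qθ)        # the Beta posterior over θ
println(mean(qθ))  # its posterior mean
```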

Multiple dispatch is no longer supported

Due to the previous change, multiple dispatch can no longer be used for model function definitions. In other words, type constraints on model arguments are ignored, because Julia does not dispatch on keyword arguments.
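One possible workaround, sketched here as an assumption rather than an officially recommended pattern, is to move the dispatch out of the model: an ordinary Julia function selects a prior by type, and the result is passed to the model as a keyword argument. The helper select_prior and the PriorKind, FlatPrior and InformativePrior types are hypothetical names introduced only for this example.

```julia
using RxInfer

@model function coin_toss(prior, y)
    θ ~ prior
    y .~ Bernoulli(θ)
end

# Hypothetical marker types used only to drive multiple dispatch
abstract type PriorKind end
struct FlatPrior <: PriorKind end
struct InformativePrior <: PriorKind end

# Hypothetical helper: dispatch happens here, outside of the @model function
select_prior(::FlatPrior) = Beta(1.0, 1.0)
select_prior(::InformativePrior) = Beta(10.0, 10.0)

observations = (y = [1, 0, 1],)

flat = infer(model = coin_toss(prior = select_prior(FlatPrior())), data = observations)
informative = infer(model = coin_toss(prior = select_prior(InformativePrior())), data = observations)

println(mean(flat.posteriors[:θ]))         # posterior mean under the flat prior
println(mean(informative.posteriors[:θ]))  # posterior mean under the informative prior
```

Because the model only ever sees a distribution, the same model definition serves every case, and the choice between priors remains ordinary Julia dispatch.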

Return value from the model function

Accessing the return value of the model function has changed. Previously, the return value was returned together with the model upon creation. Now, the return value is saved in the model's data structure, which can be accessed with the RxInfer.getreturnval function. To demonstrate the difference, previously we could do the following:

@model function test_model(a, b)
    y = datavar(Float64)
    θ ~ Beta(1.0, 1.0)
    y ~ Bernoulli(θ)
    return "Hello, world!"
end
modelgenerator = test_model(1.0, 1.0)
model, returnval = RxInfer.create_model(modelgenerator)
returnval # "Hello, world!"

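In the new API, y becomes an argument of the model, and the model generator is conditioned on the data with the | operator:

@model function test_model(y, a, b)
    θ ~ Beta(1.0, 1.0)
    y ~ Bernoulli(θ)
    return "Hello, world!"
end
modelgenerator = test_model(a = 1.0, b = 1.0) | (y = 1, )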
model = RxInfer.create_model(modelgenerator)
RxInfer.getreturnval(model)
"Hello, world!"

The InferenceResult also no longer stores the returnval field. Instead, use the model field and the RxInfer.getreturnval function:

result = infer(
    model = test_model(a = 1.0, b = 1.0),
    data  = (y = 1, )
)
RxInfer.getreturnval(result.model)
"Hello, world!"

Returning variables from the model

Similar to the previous version, you can still return latent variables from the model definition:

@model function test_model(y, a, b)
    θ ~ Beta(1.0, 1.0)
    y ~ Bernoulli(θ)
    return θ
end

However, the returned values are now internal data structures of the GraphPPL package. To access the ReactiveMP data structures (e.g., to retrieve the messages or marginals streams), use RxInfer.getvarref along with RxInfer.getvariable:

using RxInfer, ReactiveMP, Rocket
result = infer(
    model = test_model(a = 1.0, b = 1.0),
    data  = (y = 1, )
)

θlabel  = RxInfer.getreturnval(result.model)
θvarref = RxInfer.getvarref(result.model, θlabel)
θvar    = RxInfer.getvariable(θvarref)

# `|> take(1)` ensures automatic unsubscription
θmarginals_subscription = subscribe!(ReactiveMP.getmarginal(θvar) |> take(1), (qθ) -> println(qθ))
Marginal(Beta{Float64}(α=2.0, β=1.0))

Initialization

Initialization of messages and marginals to kickstart the inference procedure was previously done with the initmessages and initmarginals keyword arguments. With the introduction of nested model specification in the @model macro, a more specific way to initialize messages and marginals is needed. This is done with the new @initialization macro. Read more about the new syntax in the Initialization guide.
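A minimal sketch of the new syntax, assuming a model with an unknown mean μ and precision τ under a mean-field factorization, where iterative variational inference needs a starting marginal for τ. The model, the data values and the chosen initial distribution are illustrative assumptions; the Initialization guide remains the reference for the full syntax.

```julia
using RxInfer

@model function iid_normal(y)
    μ ~ NormalMeanVariance(0.0, 100.0)
    τ ~ GammaShapeRate(1.0, 1.0)
    for i in eachindex(y)
        y[i] ~ NormalMeanPrecision(μ, τ)
    end
end

# Mean-field assumption: the joint posterior over μ and τ factorizes
constraints = @constraints begin
    q(μ, τ) = q(μ)q(τ)
end

# Takes the role of the old initmarginals keyword: start the iterations from q(τ)
init = @initialization begin
    q(τ) = GammaShapeRate(1.0, 1.0)
end

# Illustrative (made-up) observations with sample mean 1.0
observations = [0.9, 1.1, 1.0, 0.95, 1.05]

result = infer(
    model          = iid_normal(),
    data           = (y = observations,),
    constraints    = constraints,
    initialization = init,
    iterations     = 10
)

# With iterations set, one posterior per iteration is kept; take the last one
println(mean(last(result.posteriors[:μ])))
```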