This example has been auto-generated from the examples/ folder of the GitHub repository.

Probit Model (EP)

# Activate local environment, see `Project.toml`
import Pkg; Pkg.activate(".."); Pkg.instantiate();

Estimation of pollutant level

The mortality $y_t$ of fish in a lake is observed over time. Each observation is Bernoulli distributed, $y_t \sim \text{Ber}(\Phi(x_t))$, so the mortality rate $\Phi(x_t)$ is linked to the pollutant level $x_t$ in the lake through the probit model (see below). The municipality wants to keep track of the pollution, so the pollutant level in the lake is tracked over time through observations of the fish.


The probit model lets us infer the value of a latent random process from noisy binary observations of it. RxInfer comes with support for expectation propagation (EP). In this demo we illustrate EP in the context of state estimation in a linear state-space model that combines a Gaussian state-evolution model with a discrete observation model. Here, the probit function $\Phi$ (the standard normal cumulative distribution function) links the continuous variable $x_t$ with the discrete variable $y_t$. The model is defined as follows, where the second parameter of $\mathcal{N}$ is the variance:

\[\begin{aligned} u &= 0.1 \\ x_0 &\sim \mathcal{N}(0, 100) \\ x_t &\sim \mathcal{N}(x_{t-1}+ u, 0.01) \\ y_t &\sim \mathrm{Ber}(\Phi(x_t)) \end{aligned}\]
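To get a feel for what the probit link implies for the simulated data, the sketch below follows only the mean pollutant trajectory $\mathbb{E}[x_t] = x_0 + t\,u$ and maps it through $\Phi$. It assumes the drift $u = 0.1$ from the model and the starting level $x_0 = -2$ used in `generate_data` further down, and it ignores the process noise; the variable names are illustrative only.

```julia
using StatsFuns: normcdf

# Illustrative sketch: mortality probability along the *mean* pollutant trajectory.
# Assumptions: drift u = 0.1 (as in the model) and x_0 = -2 (as in generate_data);
# process noise is ignored, so x_t is replaced by E[x_t] = x_0 + t * u.
u  = 0.1
x0 = -2.0
ts = 0:10:40

mean_x = [x0 + t * u for t in ts]   # expected pollutant level at each t
rate   = normcdf.(mean_x)           # probit link: P(y_t = 1) evaluated at E[x_t]

for (t, x, r) in zip(ts, mean_x, rate)
    println("t = $t: E[x_t] = $(round(x, digits = 2)), Φ(E[x_t]) = $(round(r, digits = 3))")
end
```

Along this mean trajectory the mortality probability rises from $\Phi(-2) \approx 0.023$ at $t = 0$ to $\Phi(2) \approx 0.977$ at $t = 40$, which suggests the binary observations should move from mostly 0 to mostly 1 over the 40 simulated steps.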

Import packages

using RxInfer, GraphPPL, StableRNGs, Random, Plots, Distributions
using StatsFuns: normcdf

Data generation

function generate_data(nr_samples::Int64; seed = 123)
    rng = StableRNG(seed)
    # hyperparameters: drift of the pollutant level per time step
    u = 0.1

    # allocate space for data
    data_x = zeros(nr_samples + 1)
    data_y = zeros(nr_samples)
    # initialize data
    data_x[1] = -2
    # generate data
    for k in eachindex(data_y)
        # calculate new x: random walk with drift u and noise variance 0.01
        data_x[k+1] = data_x[k] + u + sqrt(0.01)*randn(rng)
        # calculate y: 1 with probability Φ(x), stored as 0.0/1.0
        data_y[k] = normcdf(data_x[k+1]) > rand(rng)
    end
    # return data
    return data_x, data_y
end
n = 40
data_x, data_y = generate_data(n);
p = plot(xlabel = "t", ylabel = "x, y")
p = scatter!(p, data_y, label = "y")
p = plot!(p, data_x[2:end], label = "x")

Model specification

@model function probit_model(y, prior_x)
    # prior on the initial state (intended to be uninformative)
    x_prev ~ prior_x
    # state transitions and probit observations
    for k in eachindex(y)
        x[k] ~ Normal(mean = x_prev + 0.1, precision = 100)
        y[k] ~ Probit(x[k]) where {
            # By default, the Probit node uses the RequireMessage pipeline with a vague(NormalMeanPrecision) message as the initial value for the `in` edge.
            # The initial value can be specified manually, as done here. Changing the initial message may improve stability in some situations.
            dependencies = RequireMessageFunctionalDependencies(in = NormalMeanPrecision(0.0, 0.01))
        }
        x_prev = x[k]
    end
end

Probit Node

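The Probit node needs an initial message on its `in` edge because of the way EP computes its outgoing message. The message that the node sends towards `in` is not calculated directly. Instead, the marginal $q(in)$ is computed first, by moment matching the product of the incoming Gaussian message and the probit likelihood, and the outgoing message is then obtained from the marginalisation identity, where $\overrightarrow{\mu}(x)$ denotes the message arriving at the Probit node and $\overleftarrow{\mu}(x)$ the message it sends back:

\[\overrightarrow{\mu}(x) \overleftarrow{\mu}(x) = q(x)\]

Consequently, the outgoing message $\overleftarrow{\mu}(x) = q(x) / \overrightarrow{\mu}(x)$ can only be computed once an incoming message $\overrightarrow{\mu}(in)$ is available, so an initial one is needed to start the iterations. It can be specified as in the model above; otherwise RxInfer initialises it with a vague `NormalMeanPrecision` default.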
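The mechanism described above can be made concrete for a single probit factor. This is a minimal sketch, not RxInfer's internal rule: the hypothetical helper `probit_ep_update` takes a Gaussian incoming message $\mathcal{N}(m, v)$ and a binary observation, computes the moment-matched marginal $q(x)$ with the standard closed-form moments of a Gaussian multiplied by a probit likelihood, and then obtains the outgoing message by dividing $q(x)$ by the incoming message, i.e. by subtracting Gaussian natural parameters. It assumes StatsFuns (already used in this example) for `normlogpdf` and `normlogcdf`.

```julia
using StatsFuns: normlogpdf, normlogcdf

# Hypothetical helper (illustrative, not RxInfer's implementation): one EP update
# for a probit factor p(y | x) = Φ(s * x) with s = +1 if y is true, -1 otherwise,
# given a Gaussian incoming message N(m_in, v_in) (v_in is a variance).
# Returns ((mean, var) of q(x), (mean, var) of the outgoing message).
function probit_ep_update(y::Bool, m_in::Real, v_in::Real)
    s = y ? 1.0 : -1.0
    z = s * m_in / sqrt(1 + v_in)
    ratio = exp(normlogpdf(z) - normlogcdf(z))   # φ(z) / Φ(z), computed in log space for stability

    # Moments of the tilted distribution N(x; m_in, v_in) * Φ(s x), i.e. the marginal q(x)
    m_q = m_in + s * v_in * ratio / sqrt(1 + v_in)
    v_q = v_in - v_in^2 * ratio / (1 + v_in) * (z + ratio)

    # Outgoing message = q(x) / incoming message:
    # subtract natural parameters (precision and precision-weighted mean)
    prec_out = 1 / v_q - 1 / v_in
    xi_out   = m_q / v_q - m_in / v_in
    return (m_q, v_q), (xi_out / prec_out, 1 / prec_out)
end

# Usage: a broad incoming message N(0, 100) and one observation y = 1
(q_mean, q_var), (msg_mean, msg_var) = probit_ep_update(true, 0.0, 100.0)
println("q(x):             mean = $q_mean, var = $q_var")
println("outgoing message: mean = $msg_mean, var = $msg_var")
```

The division step is why the incoming message has to exist before the first update: without it neither $q(x)$ nor the quotient can be formed. For this likelihood the matched variance is always smaller than the incoming one, so the outgoing precision stays positive.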


result = infer(
    model = probit_model(prior_x = Normal(0.0, 100.0)),
    data  = (y = data_y, ),
    iterations = 5,
    returnvars = (x = KeepLast(),),
    free_energy = true
)
Inference results:
  Posteriors       | available for (x)
  Free Energy:     | Real[25.6698, 18.0157, 17.9199, 17.9194, 17.9194]


mx = result.posteriors[:x]

p = plot(xlabel = "t", ylabel = "x, y", legend = :bottomright)
p = scatter!(p, data_y, label = "y")
p = plot!(p, data_x[2:end], label = "x", lw = 2)
p = plot!(p, mean.(mx), ribbon = std.(mx), fillalpha = 0.2, label = "x (inferred mean)")

f = plot(xlabel = "t", ylabel = "BFE")
f = plot!(result.free_energy, label = "Bethe Free Energy")

plot(p, f, size = (800, 400))