This example has been auto-generated from the examples/ folder of the GitHub repository.

Hierarchical Gaussian Filter

# Activate local environment, see `Project.toml`
import Pkg; Pkg.activate(".."); Pkg.instantiate();

In this demo the goal is to perform approximate variational Bayesian inference for a univariate Hierarchical Gaussian Filter (HGF).

A simple HGF model can be defined as:

\[\begin{aligned} x^{(j)}_k & \sim \, \mathcal{N}(x^{(j)}_{k - 1}, f_k(x^{(j - 1)}_k)) \\ y_k & \sim \, \mathcal{N}(x^{(j)}_k, \tau_k) \end{aligned}\]

where $j$ is the index of the layer in the hierarchy, $k$ is the time step and $f_k$ is a variance activation function. RxInfer.jl exports the Gaussian Controlled Variance (GCV) node with the $f_k = \exp(\kappa x + \omega)$ variance activation function. By default the node uses Gauss-Hermite cubature with a prespecified number of approximation points. In this demo we also show how to change the hyperparameters of different approximation methods (in this case Gauss-Hermite cubature) with the help of metadata structures. Here is how our model looks with the GCV node:

\[\begin{aligned} z_k & \sim \, \mathcal{N}(z_{k - 1}, \tau_z) \\ x_k & \sim \, \mathcal{N}(x_{k - 1}, \exp(\kappa z_k + \omega)) \\ y_k & \sim \, \mathcal{N}(x_k, \tau_y) \end{aligned}\]
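To build intuition, the variance activation can be written as a plain Julia function (an illustrative sketch; the default values for κ and ω below are arbitrary and not part of the model):

# Variance activation of the GCV node: the higher layer state `z`
# controls the variance of the lower layer
f(z; κ = 1.0, ω = 0.0) = exp(κ * z + ω)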

In this experiment we create a single time step of the graph and run a variational message passing filtering algorithm to estimate the hidden states of the system. For a more rigorous introduction to the Hierarchical Gaussian Filter we refer to the paper Online Message Passing-based Inference in the Hierarchical Gaussian Filter by Ismail Senoz.

For simplicity we consider $\tau_z$, $\tau_y$, $\kappa$ and $\omega$ to be known and fixed, but there is no principled limitation preventing them from being random variables too.

To model this process in RxInfer we start by importing all needed packages:

using RxInfer, BenchmarkTools, Random, Plots

The next step is to generate some synthetic data:

function generate_data(rng, k, w, zv, yv)
    z_prev = 0.0
    x_prev = 0.0

    # `n` (the number of observations) is defined globally below
    z = Vector{Float64}(undef, n)
    v = Vector{Float64}(undef, n)
    x = Vector{Float64}(undef, n)
    y = Vector{Float64}(undef, n)

    for i in 1:n
        # Higher layer `z` is a Gaussian random walk
        z[i] = rand(rng, Normal(z_prev, sqrt(zv)))
        # The variance of the lower layer is controlled by `z`
        v[i] = exp(k * z[i] + w)
        x[i] = rand(rng, Normal(x_prev, sqrt(v[i])))
        # Noisy observations of the lower layer
        y[i] = rand(rng, Normal(x[i], sqrt(yv)))

        z_prev = z[i]
        x_prev = x[i]
    end

    return z, x, y
end
generate_data (generic function with 1 method)
# Seed for reproducibility
seed = 42

rng = MersenneTwister(seed)

# Parameters of HGF process
real_k = 1.0
real_w = 0.0
z_variance = abs2(0.2)
y_variance = abs2(0.1)

# Number of observations
n = 300

z, x, y = generate_data(rng, real_k, real_w, z_variance, y_variance);
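As a quick sanity check (not part of the original example), all generated series should contain exactly `n` points:

# Illustrative check: each series has one value per time step
@assert length(z) == length(x) == length(y) == n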

Let's plot our synthetic dataset. Lines represent the hidden states that we want to estimate using noisy observations.

let
    pz = plot(title = "Hidden States Z")
    px = plot(title = "Hidden States X")
    plot!(pz, 1:n, z, label = "z_i", color = :orange)
    plot!(px, 1:n, x, label = "x_i", color = :green)
    scatter!(px, 1:n, y, label = "y_i", color = :red, ms = 2, alpha = 0.2)
    plot(pz, px, layout = @layout([ a; b ]))
end

To create a model we use the @model macro:

# We create a single-time step of corresponding state-space process to
# perform online learning (filtering)
@model function hgf(real_k, real_w, z_variance, y_variance)
    # Priors from previous time step for `z`
    zt_min_mean = datavar(Float64)
    zt_min_var  = datavar(Float64)
    # Priors from previous time step for `x`
    xt_min_mean = datavar(Float64)
    xt_min_var  = datavar(Float64)

    zt_min ~ NormalMeanVariance(zt_min_mean, zt_min_var)
    xt_min ~ NormalMeanVariance(xt_min_mean, xt_min_var)

    # Higher layer is modelled as a random walk 
    zt ~ NormalMeanVariance(zt_min, z_variance)
    # Lower layer is modelled with `GCV` node
    gcvnode, xt ~ GCV(xt_min, zt, real_k, real_w)
    # Noisy observations 
    y = datavar(Float64)
    y ~ NormalMeanVariance(xt, y_variance)
    return gcvnode
end

To run variational inference we also specify factorization constraints for the approximate posterior:

@constraints function hgfconstraints() 
    q(xt, zt, xt_min) = q(xt, xt_min)q(zt)
end
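This constraint encodes a structured mean-field factorization around the GCV node:

\[q(z_t, x_t, x_{t - 1}) = q(x_t, x_{t - 1}) \, q(z_t)\]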

Finally, we specify the metadata for the approximation method of the GCV node:

@meta function hgfmeta()
    # Let's use 31 approximation points in the Gauss-Hermite cubature approximation method
    GCV(xt_min, xt, zt) -> GCVMetadata(GaussHermiteCubature(31)) 
end
hgfmeta (generic function with 1 method)
Now we are ready to run the inference procedure. The @autoupdates block passes the posterior marginals of zt and xt from the previous time step as priors for the current one:

function run_inference(data, real_k, real_w, z_variance, y_variance)

    autoupdates   = @autoupdates begin
        zt_min_mean, zt_min_var = mean_var(q(zt))
        xt_min_mean, xt_min_var = mean_var(q(xt))
    end

    return infer(
        model         = hgf(real_k, real_w, z_variance, y_variance),
        constraints   = hgfconstraints(),
        meta          = hgfmeta(),
        data          = (y = data, ),
        autoupdates   = autoupdates,
        keephistory   = length(data),
        historyvars   = (
            xt = KeepLast(),
            zt = KeepLast()
        ),
        initmarginals = (
            zt = NormalMeanVariance(0.0, 5.0),
            xt = NormalMeanVariance(0.0, 5.0),
        ),
        iterations    = 5,
        free_energy   = true,
        autostart     = true,
        callbacks     = (
            after_model_creation = (model, returnval) -> begin 
                gcvnode = returnval
                setmarginal!(gcvnode, :y_x, MvNormalMeanCovariance([ 0.0, 0.0 ], [ 5.0, 5.0 ]))
            end,
        )
    )
end
run_inference (generic function with 1 method)
result = run_inference(y, real_k, real_w, z_variance, y_variance);

mz = result.history[:zt];
mx = result.history[:xt];
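For instance, we can peek at the final filtered marginals (an illustrative check, using the same mean and std accessors as in the plots below):

# Posterior over the last hidden states, summarized as mean ± std
println("z: ", mean(mz[end]), " ± ", std(mz[end]))
println("x: ", mean(mx[end]), " ± ", std(mx[end]))

Now we can plot the estimates against the true hidden states: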
let
    pz = plot(title = "Hidden States Z")
    px = plot(title = "Hidden States X")
    plot!(pz, 1:n, z, label = "z_i", color = :orange)
    plot!(pz, 1:n, mean.(mz), ribbon = std.(mz), label = "estimated z_i", color = :teal)
    plot!(px, 1:n, x, label = "x_i", color = :green)
    plot!(px, 1:n, mean.(mx), ribbon = std.(mx), label = "estimated x_i", color = :violet)
    plot(pz, px, layout = @layout([ a; b ]))
end

As we can see from the plot, the estimated signal closely resembles the real hidden states, with small posterior variance. We may also be interested in the values of the Bethe Free Energy functional:

plot(result.free_energy_history, label = "Bethe Free Energy")
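As a quick numeric check (illustrative, using the same free_energy_history field as in the plot above), we can compare the first and last values:

# The free energy should decrease over the variational iterations
println(first(result.free_energy_history), " → ", last(result.free_energy_history))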

As we can see, the Bethe Free Energy converges nicely to a stable point.

Finally, let's check the overall performance of the resulting Variational Message Passing algorithm:

@benchmark run_inference($y, $real_k, $real_w, $z_variance, $y_variance)
BenchmarkTools.Trial: 45 samples with 1 evaluation.
 Range (min … max):   84.766 ms … 164.193 ms  ┊ GC (min … max): 0.00% … 0.0
 Time  (median):     109.323 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   112.571 ms ±  13.115 ms  ┊ GC (mean ± σ):  1.76% ± 3.3

  ▄▁▁▁▁▁▁▄▁▁▁▁▄▇▄▁▆████▁▁▆▄▆▄▁▁▁▄▁▁▁▁▄▁▄▁▁▁▄▁▁▁▁▁▄▁▁▁▁▁▁▁▁▁▁▁▁▄ ▁
  84.8 ms          Histogram: frequency by time          164 ms <

 Memory estimate: 18.78 MiB, allocs estimate: 456890.