Continuous transition node

The ContinuousTransition node encodes a linear (or nonlinear) Gaussian state transition:

\[y \sim \mathcal{N}(K(a) \cdot x, \, W^{-1})\]

It transforms an m-dimensional input vector x into an n-dimensional output vector y via a learned matrix K(a), where a is a latent vector and K is a user-supplied transformation function. The precision matrix W controls the amount of jitter in the transition.
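
The generative direction of the formula can be sketched in plain Julia. The helper `transition_moments` below is hypothetical and not part of ReactiveMP. It only evaluates the mean K(a)·x and the covariance W⁻¹ of y for fixed x, a and W, using the standard library `LinearAlgebra`.

```julia
using LinearAlgebra

# Hypothetical helper (not part of ReactiveMP): moments of y for fixed x, a and W,
# following y ~ N(K(a) * x, W^{-1}).
function transition_moments(K::Function, a::AbstractVector, x::AbstractVector, W::AbstractMatrix)
    A = K(a)        # n×m matrix built from the latent vector a
    μ = A * x       # mean of y
    Σ = inv(W)      # covariance of y is the inverse of the precision W
    return μ, Σ
end

K = a -> reshape(a, 2, 2)           # unstructured 2×2 transformation
a = [1.0, 0.0, 0.0, 1.0]            # reshapes (column-major) to the identity matrix
x = [3.0, -1.0]
W = [4.0 0.0; 0.0 4.0]              # precision 4 means variance 0.25 per component
μ, Σ = transition_moments(K, a, x, W)
```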

This node is the continuous-state counterpart of DiscreteTransition and the primary building block for Kalman-filter-style state-space models where the transition matrix is uncertain and must be inferred.
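
As a rough illustration of the state-space use case, the following sketch rolls out a trajectory whose transition matrix comes from K(a). The function `simulate` is a hypothetical helper that assumes a square K(a) (n = m) and a known a; in an actual model a would be inferred rather than fixed.

```julia
using LinearAlgebra, Random

# Hypothetical simulator (not a ReactiveMP function): roll out x_t = K(a) * x_{t-1} + noise,
# where the noise is Gaussian with covariance W^{-1}. Assumes K(a) is square.
function simulate(rng::AbstractRNG, K::Function, a::AbstractVector, x0::AbstractVector,
                  W::AbstractMatrix, T::Integer)
    A = K(a)
    L = cholesky(Symmetric(inv(W))).L   # lower Cholesky factor of the noise covariance
    xs = Vector{Vector{Float64}}(undef, T)
    x = x0
    for t in 1:T
        x = A * x + L * randn(rng, length(x))   # deterministic transition plus jitter
        xs[t] = x
    end
    return xs
end

rotation = a -> [cos(a[1]) -sin(a[1]); sin(a[1]) cos(a[1])]
xs = simulate(MersenneTwister(1), rotation, [π / 8], [1.0, 0.0], Matrix(1e6 * I, 2, 2), 16)
```

With the angle π/8 and 16 steps, the noiseless part of the trajectory completes one full turn, so the last state should land close to the initial state [1.0, 0.0].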

Interfaces

| Interface | Role |
|---|---|
| y | n-dimensional output state |
| x | m-dimensional input state |
| a | Vector parameterizing the transition matrix via K(a) |
| W | n×n precision matrix of the transition noise |
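
Mismatched dimensions are a common source of errors with this node. The hypothetical `check_dimensions` helper below, which is not a ReactiveMP function, spells out the constraints from the table: K(a) must have as many columns as x has entries, and W must be square with as many rows as K(a).

```julia
# Hypothetical sanity check (not part of ReactiveMP): confirm that the interface
# dimensions agree before wiring the node into a model.
function check_dimensions(K::Function, a::AbstractVector, x::AbstractVector, W::AbstractMatrix)
    A = K(a)
    n, m = size(A)
    m == length(x) ||
        throw(DimensionMismatch("K(a) has $m columns but x has length $(length(x))"))
    size(W) == (n, n) ||
        throw(DimensionMismatch("W must be $(n)×$(n), got $(size(W))"))
    return (n = n, m = m)
end

K = a -> reshape(a, 3, 2)   # maps a length-6 vector to a 3×2 matrix (n = 3, m = 2)
dims = check_dimensions(K, collect(1.0:6.0), [1.0, 2.0], [1.0 0 0; 0 1.0 0; 0 0 1.0])
```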

Specifying the transformation

The transformation K(a) is passed through ContinuousTransitionMeta (alias: CTMeta). The function must return an n×m matrix. For example:

# Unstructured: reshape a length-4 vector into a 2×2 matrix
transformation = a -> reshape(a, 2, 2)

a ~ MvNormalMeanCovariance(zeros(4), Diagonal(ones(4)))
y ~ ContinuousTransition(x, a, W) where { meta = CTMeta(transformation) }
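
Because Julia's `reshape` fills a matrix column by column, the ordering of a in the unstructured example is column-major: a[1] and a[2] form the first column of K(a). This is worth keeping in mind when choosing the prior mean of a. The snippet below uses only Base Julia.

```julia
# reshape is column-major: consecutive elements of a fill K(a) one column at a time.
transformation = a -> reshape(a, 2, 2)
A = transformation([1.0, 2.0, 3.0, 4.0])   # → [1.0 3.0; 2.0 4.0]
```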

When the matrix has known structure, K(a) can encode it explicitly:

# Rotation matrix parameterized by a single angle
transformation = a -> [cos(a[1]) -sin(a[1]); sin(a[1]) cos(a[1])]

a ~ MvNormalMeanCovariance([0.0], [1.0;;])
y ~ ContinuousTransition(x, a, W) where { meta = CTMeta(transformation) }
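
The rotation example shows the benefit of encoding structure: a single scalar is inferred instead of four free entries, and every K(a) is a proper rotation. The check below uses only `LinearAlgebra` and the transformation from the example above.

```julia
using LinearAlgebra

transformation = a -> [cos(a[1]) -sin(a[1]); sin(a[1]) cos(a[1])]
R = transformation([π / 6])

is_orthogonal = R' * R ≈ Matrix(I, 2, 2)    # the inverse of R is its transpose
unit_det = det(R) ≈ 1.0                     # proper rotation, no reflection
keeps_norm = norm(R * [3.0, 4.0]) ≈ 5.0     # the mean K(a) * x has the same length as x
```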
Note

Even for scalar transitions, a must be a vector (of length 1). Use a multivariate prior such as MvNormalMeanCovariance rather than a univariate Normal distribution for a.
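
A minimal sketch of the scalar case in plain Julia, assuming an unstructured 1×1 transformation: a has length 1, K(a) returns a 1×1 matrix, and x is likewise a length-1 vector.

```julia
# Scalar transition: a has length 1 and K(a) returns a 1×1 matrix.
transformation = a -> reshape(a, 1, 1)
A = transformation([0.9])   # 1×1 Matrix{Float64}
x = [2.0]                   # the input is also a length-1 vector
y_mean = A * x              # length-1 vector holding K(a) * x
```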

Factorization constraints

The node supports two factorization assumptions:

Mean-field — all variables are treated as independent:

q(y, x, a, W) = q(y)q(x)q(a)q(W)

Structured — the joint q(y, x) is kept intact (useful for Kalman smoothing):

q(y, x, a, W) = q(y, x)q(a)q(W)
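
Why the structured factorization keeps more information can be seen from the expected log-likelihood that variational message passing works with. The derivation below is a standard Gaussian identity written as a sketch, with A = K(a) and S the expected outer product of the residual y − Ax under q; it is not transcribed from the ReactiveMP source.

```latex
\begin{aligned}
\log \mathcal{N}(y \mid Ax, W^{-1}) &= -\tfrac{n}{2}\log 2\pi + \tfrac{1}{2}\log\lvert W\rvert - \tfrac{1}{2}\operatorname{tr}\big(W\,(y - Ax)(y - Ax)^\top\big) \\
\mathbb{E}_q\big[\log \mathcal{N}(y \mid Ax, W^{-1})\big] &= -\tfrac{n}{2}\log 2\pi + \tfrac{1}{2}\mathbb{E}[\log\lvert W\rvert] - \tfrac{1}{2}\operatorname{tr}\big(\mathbb{E}[W]\, S\big) \\
\text{mean-field:}\quad S &= \mathbb{E}[yy^\top] - \mathbb{E}[y]\,\mathbb{E}[x]^\top\mathbb{E}[A]^\top - \mathbb{E}[A]\,\mathbb{E}[x]\,\mathbb{E}[y]^\top + \mathbb{E}_a\big[A\,\mathbb{E}[xx^\top]\,A^\top\big] \\
\text{structured:}\quad S &= \mathbb{E}[yy^\top] - \mathbb{E}[yx^\top]\,\mathbb{E}[A]^\top - \mathbb{E}[A]\,\mathbb{E}[xy^\top] + \mathbb{E}_a\big[A\,\mathbb{E}[xx^\top]\,A^\top\big]
\end{aligned}
```

The two cases differ only in the cross terms: the structured factorization uses the joint moment E[yxᵀ] instead of the product E[y]E[x]ᵀ, so the correlation between consecutive states is retained.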

Companion matrix

For autoregressive-style transitions, the companion matrix representation converts an AR coefficient vector into a state transition matrix. See CompanionMatrix in the algebra utilities and the Autoregressive node for a specific application.
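
For illustration, one common convention for the companion matrix of an AR(p) process puts the coefficients in the first row and ones on the sub-diagonal. The `companion` function below is a hypothetical sketch of that convention, not ReactiveMP's CompanionMatrix, whose exact layout may differ.

```julia
# Hypothetical helper (not ReactiveMP's CompanionMatrix): companion matrix of an AR(p)
# process x_t = θ[1] x_{t-1} + ... + θ[p] x_{t-p} + noise.
function companion(θ::AbstractVector{T}) where {T<:Real}
    p = length(θ)
    A = zeros(T, p, p)
    A[1, :] = θ                 # first row holds the AR coefficients
    for i in 2:p
        A[i, i - 1] = one(T)    # sub-diagonal shifts each lagged state down by one slot
    end
    return A
end

A = companion([0.5, -0.2, 0.1])
# The stacked state [x_t, x_{t-1}, x_{t-2}] advances as A * state.
```

Wrapped as `transformation = a -> companion(a)`, such a function could in principle serve as K(a) for an AR-style transition with a length-p coefficient vector a.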

ReactiveMP.ContinuousTransition (Type)

The functional form of the ContinuousTransition node is given by: y ~ Normal(K(a) * x, W⁻¹)

This node transforms an m-dimensional vector x into an n-dimensional vector y via a linear (or nonlinear) transformation with an n×m matrix A that is constructed from a vector a via a transformation K(a). The ContinuousTransition node is primarily used in two regimes:

When no structure on A is specified:

transformation = a -> reshape(a, 2, 2)
...
a ~ MvNormalMeanCovariance(zeros(4), Diagonal(ones(4)))
y ~ ContinuousTransition(x, a, W) where {meta = CTMeta(transformation)}
...

When some structure of A is known:

transformation = a -> [cos(a[1]) -sin(a[1]); sin(a[1]) cos(a[1])]
...
a ~ MvNormalMeanCovariance(zeros(1), Diagonal(ones(1)))
y ~ ContinuousTransition(x, a, W) where {meta = CTMeta(transformation)}
...

To construct the matrix A, the elements of a are mapped into A by the transformation function provided in the meta. If you intend to use a univariate Gaussian prior, use a vector of length 1 instead, e.g. `a ~ MvNormalMeanCovariance([0.0], [1.0;;])`.

See ContinuousTransitionMeta for details on specifying the transformation function, which must return a matrix.

y ~ ContinuousTransition(x, a, W) where {meta = ContinuousTransitionMeta(transformation)}

Interfaces:

  1. y - n-dimensional output of the ContinuousTransition node.
  2. x - m-dimensional input of the ContinuousTransition node.
  3. a - vector of arbitrary length that the transformation K maps to the matrix A.
  4. W - n×n precision matrix of the transition noise, used to soften the transition so that variational message passing can be performed.

Note that you can set W to a fixed value or put a prior on it to control the amount of jitter.

The ContinuousTransition node supports two factorizations:

  1. Mean-field factorization:
@constraints begin
    q(y, x, a, W) = q(y)q(x)q(a)q(W)
end
  2. Structured factorization:
@constraints begin
    q(y, x, a, W) = q(y, x)q(a)q(W)
end
ReactiveMP.ContinuousTransitionMeta (Type)

ContinuousTransitionMeta is used as a metadata flag in ContinuousTransition to define the transformation function for constructing the matrix A from vector a.

ContinuousTransitionMeta requires a transformation function and a vector â, which acts as the expansion point for approximating the transformation linearly. If the transformation is linear, no approximation is performed.
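
What "linear" means here can be made concrete: K is linear in a if K(αa + βb) = αK(a) + βK(b). The probe below is a hypothetical illustration of that property with random inputs; it is not how ReactiveMP decides whether to approximate.

```julia
# Hypothetical probe (not ReactiveMP's logic): test K(α a + β b) ≈ α K(a) + β K(b)
# on a few random draws of a, b, α and β.
function looks_linear(K::Function, d::Integer; trials = 5, atol = 1e-10)
    for _ in 1:trials
        a, b = randn(d), randn(d)
        α, β = randn(), randn()
        isapprox(K(α .* a .+ β .* b), α .* K(a) .+ β .* K(b); atol = atol) || return false
    end
    return true
end
```

The unstructured reshape passes this probe, while the rotation transformation fails it because cos and sin are not linear in the angle.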

Constructors:

  • ContinuousTransitionMeta(transformation::Function, â::Vector{<:Real}): Constructs a ContinuousTransitionMeta struct with the transformation function and allocated basis vectors.

Fields:

  • f: the transformation function that maps vector a to matrix A.

The ContinuousTransitionMeta struct thus determines how vector a is transformed into matrix A, and with it the transition that the ContinuousTransition node encodes.

ReactiveMP.ctcompanion_matrix (Function)
`ctcompanion_matrix` casts a vector `a` into a matrix `A` by linearizing the transformation function `f` around the expansion point `a0`.
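
The linearization itself is a first-order Taylor expansion of the matrix-valued transformation. The sketch below, with the hypothetical name `linearize`, estimates the partial derivatives ∂f/∂aᵢ by central finite differences rather than whatever differentiation ReactiveMP uses internally, and returns the affine approximation around a0.

```julia
# Hypothetical sketch (not ReactiveMP's ctcompanion_matrix): first-order expansion
#   f(a) ≈ f(a0) + Σ_i (a[i] - a0[i]) * ∂f/∂a_i(a0),
# with each partial derivative estimated by a central finite difference.
function linearize(f::Function, a0::AbstractVector{<:Real}; h = 1e-6)
    F0 = f(a0)
    J = map(eachindex(a0)) do i
        e = zeros(length(a0)); e[i] = h
        (f(a0 .+ e) .- f(a0 .- e)) ./ (2h)   # ∂f/∂a_i at a0, same shape as f(a0)
    end
    return a -> F0 .+ sum((a[i] - a0[i]) .* J[i] for i in eachindex(a0))
end

rotation = a -> [cos(a[1]) -sin(a[1]); sin(a[1]) cos(a[1])]
lin = linearize(rotation, [0.0])   # approximately [1 -a; a 1] near a = 0
```

For a linear transformation such as the unstructured reshape, the expansion reproduces K(a) up to floating-point error, which matches the statement that no approximation is needed in that case.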