Discrete transition node

The DiscreteTransition node encodes a Markov state transition for discrete categorical variables. It represents the conditional distribution:

\[p(\text{out} \mid \text{in}, A) = \mathrm{Cat}(\text{out} \mid A \cdot \text{in})\]

where out and in are categorical (discrete) state variables and A is a column-stochastic transition matrix (each column sums to one). This is the fundamental building block for Hidden Markov Models (HMMs) and other discrete state-space models.
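To make the column-stochastic convention concrete, here is a minimal base-Julia sketch that builds a small matrix by hand and applies it to a one-hot state. It makes no ReactiveMP calls; the matrix values and the variable names (column_sums, state_in, p_out) are assumptions chosen for illustration.

```julia
# Illustrative 3-state transition matrix; the values are made up for this sketch.
# Column j holds the distribution over the next state given current state j.
A = [0.9 0.1 0.0;
     0.1 0.8 0.3;
     0.0 0.1 0.7]

# Column-stochastic: every column sums to one.
column_sums = vec(sum(A, dims = 1))

# One-hot encoding of "current state = 2".
state_in = [0.0, 1.0, 0.0]

# The distribution over the next state is the matching column of A.
p_out = A * state_in
```

With a one-hot input, the product simply selects a column of A; with a soft belief over the current state, it returns the corresponding mixture of columns.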

Interfaces

The DiscreteTransition node accepts a variable number of inputs:

Interface index   Alias         Role
1                 out           Next (output) state
2                 in            Current (input) state
3                 a             Transition matrix variable
4, 5, …           T1, T2, …     Optional additional transition matrices

This flexible interface allows multi-dimensional transition structures where the full transition is a product of several matrices.

Comparison with a plain Categorical node

A plain Categorical node fixes the probability vector at the time the node is created. DiscreteTransition is different in two important ways:

  1. The transition matrix a is a variable — it can have a DirichletCollection prior and its posterior is inferred jointly with the states.
  2. The input state is also a variable — messages flow in both directions, making it possible to infer both past states (smoothing) and future states (prediction).
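The two message directions can be sketched with plain matrix products. This is a simplified illustration that assumes A is known exactly (a point mass), so no variational expectation is needed; it is not the node's actual rule code, and the names belief_in, evidence_out, forward and backward are hypothetical.

```julia
# Made-up 2-state column-stochastic matrix for this sketch.
A = [0.9 0.2;
     0.1 0.8]

# Forward (prediction): push a belief over the current state through A.
belief_in = [0.5, 0.5]
forward = A * belief_in            # distribution over the next state

# Backward (smoothing): pull evidence about the next state back through A'.
evidence_out = [0.0, 1.0]          # e.g. the next state is known to be 2
backward = A' * evidence_out       # unnormalised likelihood of each current state
backward ./= sum(backward)         # normalise to a distribution
```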

Typical usage pattern

# prior on initial state
s[1] ~ Categorical(fill(1/K, K))

# prior on transition matrix (one Dirichlet per column)
A ~ DirichletCollection(ones(K, K))

# Markov chain
for t in 2:T
    s[t] ~ DiscreteTransition(s[t-1], A)
end

# emission likelihoods (B is an emission matrix, assumed to be defined elsewhere)
for t in 1:T
    y[t] ~ Categorical(B * s[t])
end
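
For testing a model with this structure, it can help to generate synthetic data from the same generative process. The sketch below does that in base Julia under a few assumptions: states and observations are stored as integer indices rather than one-hot vectors, both A and B are column-stochastic, and the helpers sample_categorical and simulate_hmm are hypothetical names that are not part of ReactiveMP.

```julia
using Random

# Draw an index from a probability vector by inverting its cumulative sum.
# The `min` guards against round-off in the last cumulative value.
sample_categorical(rng, p) = min(searchsortedfirst(cumsum(p), rand(rng)), length(p))

function simulate_hmm(rng, A, B, T)
    K = size(A, 1)
    s = Vector{Int}(undef, T)
    y = Vector{Int}(undef, T)
    s[1] = sample_categorical(rng, fill(1 / K, K))      # uniform initial state
    y[1] = sample_categorical(rng, B[:, s[1]])
    for t in 2:T
        s[t] = sample_categorical(rng, A[:, s[t - 1]])  # column of A for the previous state
        y[t] = sample_categorical(rng, B[:, s[t]])      # column of B for the current state
    end
    return s, y
end

# Made-up transition and emission matrices.
A = [0.9 0.2; 0.1 0.8]
B = [0.7 0.1; 0.3 0.9]
s, y = simulate_hmm(MersenneTwister(1), A, B, 20)
```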

Utility functions

The following internal functions implement the message update rules for the DiscreteTransition node. They are exposed for users who want to reuse them in custom rule definitions.

ReactiveMP.discrete_transition_decode_marginal (Function)
discrete_transition_decode_marginal(marginal_name::String, marginal::Contingency{T, <:AbstractArray{T, N}}) where {T, N}

Decode the marginal distribution into a tuple of dimensions and the probability tensor of the contingency distribution.

Arguments

  • marginal_name: The name of the marginal distribution.
  • marginal: The marginal distribution.

For example, if the marginal name is "in_t1_t5", then "in" corresponds to dimension 2 of the contingency tensor, and "t1" and "t5" correspond to dimensions 3 and 7. The function therefore returns (2, 3, 7) together with the contingency tensor marginal.p.
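
The name-to-dimension mapping in this example can be sketched as follows. This is a hypothetical helper, not the ReactiveMP implementation: it assumes the interface names in a joint marginal name are joined by underscores (for example "in_t1_t5"), that "out" is dimension 1, "in" is dimension 2, and "tK" is dimension K + 2.

```julia
# Hypothetical helper, not the library code: map a joint marginal name
# to the contingency-tensor dimensions it refers to.
function name_to_dims(marginal_name::AbstractString)
    dims = map(split(marginal_name, '_')) do part
        if part == "out"
            1
        elseif part == "in"
            2
        elseif startswith(part, "t")
            parse(Int, part[2:end]) + 2   # "t1" -> 3, "t5" -> 7
        else
            error("unrecognised interface name: $part")
        end
    end
    return Tuple(dims)
end

decoded = name_to_dims("in_t1_t5")
```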

ReactiveMP.discrete_transition_marginal_rule (Function)
discrete_transition_marginal_rule(message_names, messages, marginals_names, marginals, q_a)

Compute the marginal for one of the Categorical interfaces of the DiscreteTransition node. This function is similar to discrete_transition_structured_message_rule but combines the messages with multiply_dimensions! instead of sum_out_dimensions.

Arguments

  • message_names: The names of the incoming messages. These are the variables in the same factorization cluster as the variable over which we are computing the marginal.
  • messages: The incoming messages. These are guaranteed to be either Categorical, Bernoulli or PointMass distributions.
  • marginals_names: The names of the other marginal distributions attached to the DiscreteTransition node. These marginal distributions are not in the same factorization cluster as the variable over which we are computing the marginal.
  • marginals: The incoming marginals. These are guaranteed to be either Contingency, Categorical, Bernoulli or PointMass distributions.
  • q_a: The marginal distribution over the transition tensor.

ReactiveMP.discrete_transition_process_marginals (Function)
discrete_transition_process_marginals(e_log_a, marginals_names, marginals)

Process the marginals to update the expected log transition matrix. This is a common operation used by both discrete_transition_structured_message_rule and discrete_transition_marginal_rule.

Arguments

  • e_log_a: The expected log of the transition matrix.
  • marginals_names: The names of the marginal distributions.
  • marginals: The marginal distributions.

Returns

  • The updated expected log transition matrix.

ReactiveMP.multiply_dimensions! (Function)
multiply_dimensions!(tensor::AbstractArray{T, M}, dims::NTuple{N, Int}, values::AbstractArray{T, N}) where {T, M, N}

Multiply the tensor with the values along the specified dimensions. This is similar to sum_out_dimensions, but it does not sum the result; it only performs the elementwise multiplication.
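
The operation described above can be illustrated with broadcasting. The sketch below is an out-of-place stand-in under stated assumptions, not the library code: the helper name multiply_along is hypothetical, and dims is assumed to be sorted in ascending order so that a plain reshape lines up the axes of values with those of tensor.

```julia
# Hypothetical out-of-place sketch: multiply `tensor` elementwise by `values`,
# whose k-th axis lines up with axis `dims[k]` of `tensor`.
function multiply_along(tensor::AbstractArray{T, M}, dims::NTuple{N, Int},
                        values::AbstractArray{T, N}) where {T, M, N}
    # Give `values` singleton axes everywhere except at `dims`, then broadcast.
    shape = ones(Int, M)
    for (k, d) in enumerate(dims)
        shape[d] = size(values, k)
    end
    return tensor .* reshape(values, shape...)
end

# Scale axis 2 of a 2×3×4 tensor of ones by the weights 1, 2, 3.
scaled = multiply_along(ones(2, 3, 4), (2,), [1.0, 2.0, 3.0])
```

The result keeps the full shape of the input tensor; only the entries are reweighted.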

ReactiveMP.sum_out_dimensions (Function)
sum_out_dimensions(tensor::AbstractArray{T, M}, dims::NTuple{N, Int}, values::AbstractArray{T, N}) where {T, M, N}

Sum out the dimensions of the tensor that are not part of the marginal distribution. This is a generalization of an inner product, where we also figure out which dimensions of the tensor align with the dimensions of values.
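
One way to read this generalized inner product is "broadcast-multiply, then sum over the aligned axes". The following base-Julia sketch shows that reading for illustration only; sum_out is a hypothetical name, dims is assumed to be sorted in ascending order, and the real function may differ in details.

```julia
# Hypothetical sketch: weight `tensor` by `values` along `dims`,
# then sum those axes away.
function sum_out(tensor::AbstractArray, dims::Tuple, values::AbstractArray)
    shape = ones(Int, ndims(tensor))
    for (k, d) in enumerate(dims)
        shape[d] = size(values, k)
    end
    weighted = tensor .* reshape(values, shape...)
    return dropdims(sum(weighted; dims = dims); dims = dims)
end

A = [0.9 0.2; 0.1 0.8]
q = [0.3, 0.7]
along_in = sum_out(A, (2,), q)     # matches the matrix-vector product A * q
along_out = sum_out(A, (1,), q)    # matches A' * q
```

For a matrix and a vector this reduces to an ordinary matrix-vector product, which is the sense in which the operation generalizes an inner product.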

ReactiveMP.discrete_transition_process_messages (Function)
discrete_transition_process_messages(e_log_a, message_names, messages, callback)

Process the messages to update the expected log transition matrix. This is a common operation used by both discrete_transition_structured_message_rule and discrete_transition_marginal_rule. The callback function is used to update the expected log transition matrix. This argument toggles between marginalising out a variable (for messages) and computing a joint marginal distribution.

ReactiveMP.discrete_transition_structured_message_rule (Function)
discrete_transition_structured_message_rule(message_names, messages, marginals_names, marginals, q_a)

Compute the message for one of the Categorical interfaces of the DiscreteTransition node. This function:

  1. Computes the expected log of the transition matrix, e_log_a.
  2. Determines, for every incoming marginal distribution, which dimension of the contingency tensor it corresponds to.
  3. Uses sum_out_dimensions to compute the inner product of e_log_a with each marginal distribution along the corresponding dimension.
  4. Exponentiates the result, which is the VMP message, and multiplies it with the incoming messages.
  5. Normalizes the result to sum to 1.

Arguments

  • message_names: The names of the incoming messages. These are the variables in the same factorization cluster as the variable over which we are computing the message.
  • messages: The incoming messages. These are guaranteed to be either Categorical, Bernoulli or PointMass distributions.
  • marginals_names: The names of the other marginal distributions attached to the DiscreteTransition node. These marginal distributions are not in the same factorization cluster as the variable over which we are computing the message.
  • marginals: The incoming marginals. These are guaranteed to be either Contingency, Categorical, Bernoulli or PointMass distributions.
  • q_a: The marginal distribution over the transition tensor.
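
The steps of this rule can be sketched in base Julia for a small case. This is one reading of the description, not the library code, and it rests on several assumptions: the transition tensor is K×K×K and indexed as [out, in, t1]; the expected log tensor is taken to be the log of a fixed, made-up tensor instead of an expectation under q_a; "in" shares a cluster with "out" and contributes a message m_in; and "t1" contributes a marginal q_t1. All variable names are hypothetical.

```julia
K = 2

# Step 1 (assumed given): expected log of the transition tensor.
a = cat([0.9 0.2; 0.1 0.8], [0.5 0.4; 0.5 0.6]; dims = 3)
e_log_a = log.(a)

# Steps 2-3: q(t1) lives on dimension 3; take the inner product along it.
q_t1 = [0.25, 0.75]
log_factor = dropdims(sum(e_log_a .* reshape(q_t1, 1, 1, K); dims = 3); dims = 3)

# Step 4: exponentiate, then combine with the incoming message on `in`
# by multiplying along dimension 2 and summing it out.
m_in = [0.6, 0.4]
msg_out = exp.(log_factor) * m_in

# Step 5: normalise so the message sums to one.
msg_out ./= sum(msg_out)
```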