This example has been auto-generated from the examples/ folder at the GitHub repository.
Chance-Constrained Active Inference
This notebook applies reactive message passing for active inference in the context of chance constraints. The implementation is based on (van de Laar et al., 2021, "Chance-constrained active inference") and discussions with John Boik.
We consider a 1-D agent that tries to elevate itself above ground level. Instead of a goal prior, we impose a chance constraint on future states, such that the agent prefers to avoid the ground with a preset probability (chance) level.
using Pkg; Pkg.activate(".."); Pkg.instantiate();
using Plots, Distributions, StatsFuns, RxInfer
Pkg.status()
Status `~/work/RxInfer.jl/RxInfer.jl/docs/src/examples/Project.toml`
[b4ee3484] BayesBase v1.5.0
[6e4b80f9] BenchmarkTools v1.5.0
[336ed68f] CSV v0.10.15
[a93c6f00] DataFrames v1.7.0
[31c24e10] Distributions v0.25.113
[62312e5e] ExponentialFamily v1.6.0
⌃ [587475ba] Flux v0.14.25
[38e38edf] GLM v1.9.0
[b3f8163a] GraphPPL v4.3.4
[34004b35] HypergeometricFunctions v0.3.25
[7073ff75] IJulia v1.26.0
[4138dd39] JLD v0.13.5
[b964fa9f] LaTeXStrings v1.4.0
[429524aa] Optim v1.10.0
⌅ [3bd65402] Optimisers v0.3.4
[d96e819e] Parameters v0.12.3
[91a5bcdd] Plots v1.40.9
[92933f4c] ProgressMeter v1.10.2
[a194aa59] ReactiveMP v4.4.5
[37e2e3b7] ReverseDiff v1.15.3
[df971d30] Rocket v1.8.1
[86711068] RxInfer v3.7.2 `~/work/RxInfer.jl/RxInfer.jl`
[276daf66] SpecialFunctions v2.4.0
[860ef19b] StableRNGs v1.0.2
[4c63d2b9] StatsFuns v1.3.2
[f3b207a7] StatsPlots v0.15.7
[44d3d7a6] Weave v0.10.12
[fdbf4ff8] XLSX v0.10.4
Info Packages marked with ⌃ and ⌅ have new versions available. Those with ⌃ may be upgradable, but those with ⌅ are restricted by compatibility constraints from upgrading. To see why use `status --outdated`
Chance-Constraint Node Definition
A chance constraint is meant to constrain a marginal distribution to abide by certain properties. In this case, a (posterior) probability distribution should not "overflow" a given region by more than a certain probability mass. This constraint then affects the adjacent beliefs and ultimately the controls, so that the agent (hopefully) accounts for the imposed constraint.
In order to enforce this constraint on a marginal distribution, an auxiliary chance-constraint node is included in the graphical model. This node then sends messages that enforce the marginal to abide by the preset conditions. In other words, the (chance) constraint on the (posterior) marginal is converted to a prior constraint on the generative model that sends an adaptive message. We start by defining this chance-constraint node and its message.
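In symbols, for a safe region $G = [lo, hi]$ and an allowed overflow probability $\epsilon$, the constraint imposed on the marginal $q(x)$ can be written as

$$\int_{lo}^{hi} q(x)\,\mathrm{d}x \;\geq\; 1 - \epsilon\,,$$

that is, at most a probability mass of $\epsilon$ may lie outside the safe region.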
struct ChanceConstraint end
# Node definition with safe region limits (lo, hi), overflow chance epsilon and tolerance atol
@node ChanceConstraint Stochastic [out, lo, hi, epsilon, atol]
# Function to compute normalizing constant and central moments of a truncated Gaussian distribution
function truncatedGaussianMoments(m::Float64, V::Float64, a::Float64, b::Float64)
    V = clamp(V, tiny, huge)
    StdG = Distributions.Normal(m, sqrt(V))
    TrG = Distributions.Truncated(StdG, a, b)
    Z = Distributions.cdf(StdG, b) - Distributions.cdf(StdG, a) # Safe mass under the untruncated Gaussian
    if Z < tiny
        # Invalid region; return undefined mean and variance of truncated distribution
        Z = 0.0
        m_tr = 0.0
        V_tr = 0.0
    else
        m_tr = Distributions.mean(TrG)
        V_tr = Distributions.var(TrG)
    end
    return (Z, m_tr, V_tr)
end;
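As a quick sanity check (the numbers below are arbitrary illustration values, not part of the original notebook), the helper returns the mass of a Gaussian inside a region together with the moments of the corresponding truncated distribution:
# Hypothetical example: a standard Gaussian and the safe region [1.0, Inf)
(Z, m_tr, V_tr) = truncatedGaussianMoments(0.0, 1.0, 1.0, Inf)
Z # ≈ 0.159: the mass of N(0, 1) above 1.0; m_tr and V_tr are the moments of that truncated tail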
@rule ChanceConstraint(:out, Marginalisation) (
    m_out::UnivariateNormalDistributionsFamily, # Require inbound message
    q_lo::PointMass,
    q_hi::PointMass,
    q_epsilon::PointMass,
    q_atol::PointMass) = begin

    # Extract parameters
    lo = mean(q_lo)
    hi = mean(q_hi)
    epsilon = mean(q_epsilon)
    atol = mean(q_atol)

    (m_bw, V_bw) = mean_var(m_out)
    (xi_bw, W_bw) = (m_bw / V_bw, 1.0 / V_bw) # Canonical parameters of the inbound message (check division by zero)
    (m_tilde, V_tilde) = (m_bw, V_bw)

    # Compute statistics (and normalizing constant) of q in safe region G
    # Phi_G is called the "safe mass"
    (Phi_G, m_G, V_G) = truncatedGaussianMoments(m_bw, V_bw, lo, hi)

    xi_fw = xi_bw
    W_fw = W_bw
    if epsilon <= 1.0 - Phi_G # If constraint is active
        # Initialize statistics of uncorrected belief
        m_tilde = m_bw
        V_tilde = V_bw
        for i = 1:100 # Iterate at most this many times
            (Phi_lG, m_lG, V_lG) = truncatedGaussianMoments(m_tilde, V_tilde, -Inf, lo) # Statistics for q in region left of G
            (Phi_rG, m_rG, V_rG) = truncatedGaussianMoments(m_tilde, V_tilde, hi, Inf) # Statistics for q in region right of G

            # Compute moments of non-G region as a mixture of left and right truncations
            Phi_nG = Phi_lG + Phi_rG
            m_nG = Phi_lG / Phi_nG * m_lG + Phi_rG / Phi_nG * m_rG
            V_nG = Phi_lG / Phi_nG * (V_lG + m_lG^2) + Phi_rG / Phi_nG * (V_rG + m_rG^2) - m_nG^2

            # Compute moments of corrected belief as a mixture of G and non-G regions
            m_tilde = (1.0 - epsilon) * m_G + epsilon * m_nG
            V_tilde = (1.0 - epsilon) * (V_G + m_G^2) + epsilon * (V_nG + m_nG^2) - m_tilde^2

            # Re-compute statistics (and normalizing constant) of corrected belief
            (Phi_G, m_G, V_G) = truncatedGaussianMoments(m_tilde, V_tilde, lo, hi)
            if (1.0 - Phi_G) < (1.0 + atol) * epsilon
                break # Break the loop if the belief is sufficiently corrected
            end
        end

        # Convert moments of corrected belief to canonical form
        W_tilde = inv(V_tilde)
        xi_tilde = W_tilde * m_tilde

        # Compute canonical parameters of forward message
        xi_fw = xi_tilde - xi_bw
        W_fw = W_tilde - W_bw
    end

    return NormalWeightedMeanPrecision(xi_fw, W_fw)
end
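To see the corrective effect of the rule in isolation, it can optionally be evaluated by hand with ReactiveMP's `@call_rule` utility. The numbers below are arbitrary illustration values, not part of the original notebook; an inbound belief centred below the safe region should produce an outbound message that pulls the posterior above `lo`:
# Hypothetical check of the outbound message (arbitrary example numbers)
corrective_msg = @call_rule ChanceConstraint(:out, Marginalisation) (
    m_out     = NormalMeanVariance(0.0, 1.0), # Inbound belief, mostly below the safe region
    q_lo      = PointMass(1.0),               # Safe region lower bound
    q_hi      = PointMass(Inf),               # Safe region upper bound
    q_epsilon = PointMass(0.01),              # Allowed overflow mass
    q_atol    = PointMass(0.01)               # Convergence tolerance
)
mean(corrective_msg) # Expected to lie above the inbound mean of 0.0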
Definition of the Environment
We consider an environment where the agent has an elevation level, and where the agent directly controls its vertical velocity. After some time, an unexpected and sudden gust of wind tries to push the agent to the ground.
wind(t::Int64) = -0.1*(60 <= t < 100) # Time-dependent wind profile
function initializeWorld()
    x_0 = 0.0 # Initial elevation
    x_t_last = x_0
    function execute(t::Int64, a_t::Float64)
        x_t = x_t_last + a_t + wind(t) # Update elevation
        x_t_last = x_t # Reset state
        return x_t
    end

    x_t = x_0 # Predefine outcome variable
    observe() = x_t # State is fully observed

    return (execute, observe)
end;
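As a small illustration (not part of the original notebook), the returned closures can be exercised directly: `execute` advances the elevation and `observe` returns the fully observed state.
# Hypothetical illustration: step a fresh world with zero control while the wind is active
(execute_demo, observe_demo) = initializeWorld()
execute_demo(60, 0.0) # At t = 60 the wind profile contributes -0.1
observe_demo()        # Elevation is now -0.1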
Generative Model for Regulator
We consider a fully observed Markov decision process, where the agent directly observes the true state (elevation) of the world. In this case we only need to define a chance-constrained generative model of future states. Inference of the controls in this model then yields our controller.
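Written out (as a paraphrase of the model code below), for a lookahead horizon $T$ and an observed state $x_t$ the regulator model is

$$p(x_{1:T}, u_{1:T} \mid x_t) \propto \prod_{k=1}^{T} \mathcal{N}(u_k \mid m_{u,k}, v_{u,k})\, \delta\big(x_k - (x_{k-1} + u_k)\big), \qquad x_0 = x_t,$$

where every future state $x_k$ is additionally subjected to the chance constraint defined above.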
# m_u :: Vector{Float64}   Control prior means
# v_u :: Vector{Float64}   Control prior variances
# x_t :: Float64           Fully observed state
@model function regulator_model(T, m_u, v_u, x_t, lo, hi, epsilon, atol)
    # Loop over horizon
    x_k_last = x_t
    for k = 1:T
        u[k] ~ NormalMeanVariance(m_u[k], v_u[k]) # Control prior
        x[k] ~ x_k_last + u[k] # Transition model
        x[k] ~ ChanceConstraint(lo, hi, epsilon, atol) where { # Simultaneous constraint on state
            dependencies = RequireMessageFunctionalDependencies(out = NormalWeightedMeanPrecision(0, 0.01)) } # Predefine inbound message to break circular dependency
        x_k_last = x[k]
    end
end
Reactive Agent Definition
function initializeAgent()
    # Set control prior statistics
    m_u = zeros(T)
    v_u = lambda^(-1)*ones(T)

    function compute(x_t::Float64)
        model_t = regulator_model(; T = T, lo = lo, hi = hi, epsilon = epsilon, atol = atol)
        data_t = (m_u = m_u, v_u = v_u, x_t = x_t)

        result = infer(
            model = model_t,
            data = data_t,
            iterations = n_its)

        # Extract policy from inference results
        pol = mode.(result.posteriors[:u][end])

        return pol
    end

    pol = zeros(T) # Predefine policy variable
    act() = pol[1]

    return (compute, act)
end;
Action-Perception Cycle
Next we define and execute the action-perception cycle. Because the state is fully observed, there is no slide (state estimation) step in the cycle.
# Simulation parameters
N = 160 # Total simulation time
T = 1 # Lookahead time horizon
lambda = 1.0 # Control prior precision
lo = 1.0 # Chance region lower bound
hi = Inf # Chance region upper bound
epsilon = 0.01 # Allowed chance violation
atol = 0.01 # Convergence tolerance for chance constraints
n_its = 10; # Number of inference iterations
(execute, observe) = initializeWorld() # Let there be a world
(compute, act) = initializeAgent() # Let there be an agent
a = Vector{Float64}(undef, N) # Actions
x = Vector{Float64}(undef, N) # States
for t = 1:N
    a[t] = act()
    execute(t, a[t])
    x[t] = observe()
    compute(x[t])
end
Results
The results show that the agent does not allow the wind to push it all the way to the ground.
p1 = plot(1:N, wind.(1:N), color="blue", label="Wind", ylabel="Velocity", lw=2)
plot!(p1, 1:N, a, color="red", label="Control", lw=2)
p2 = plot(1:N, x, color="black", lw=2, label="Agent", ylabel="Elevation")
plot(p1, p2, layout=(2,1))