Deep SPI

Safe Policy Improvement via World Models

Implementation of the techniques presented in our paper Deep SPI: Safe Policy Improvement via World Models.

The Core Problem

When you train a deep RL policy with auxiliary losses that shape its representation (the observation encoder feeding your policy and value heads), you face a timing problem: the representation optimized under your behavioral policy may not be reliable for the next policy. Add auxiliary losses to regularize the latent space, improve the policy, and the encoder shifts, invalidating the very representation you were relying on.

Deep SPI solves this by coupling world-model learning with controlled policy updates: you improve the policy step by step, in a neighborhood that keeps it close to regions where the world model is well-calibrated. This way, updates that look good in the model actually translate to improvements in the real environment.

How It Works: The Algorithm

Deep SPI operates on a deep RL agent whose learned representation $\phi$ (the encoder) maps raw observations into a compact latent space. This representation is shared by the policy and the world model: the policy predicts actions from $\phi(s)$, and the world model predicts rewards and transitions from $\phi(s)$ and the action $a$. The difficulty is that updates which improve the policy can also shift this shared representation, breaking the world model’s calibration. Deep SPI counters this by constraining policy updates to a neighborhood where the world model remains reliable.
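The following is a minimal sketch of this shared architecture. It assumes a one-layer tanh encoder, a linear softmax policy head, a linear per-action reward head, and a deterministic per-action linear transition head. All layer choices and the names `init_params` and `forward` are hypothetical, not the repository’s network.

```python
import jax
import jax.numpy as jnp

def init_params(key, obs_dim, latent_dim, num_actions, scale=0.1):
    k_enc, k_pi, k_r, k_p = jax.random.split(key, 4)
    return {
        # phi: observation -> latent (hypothetical one-layer encoder)
        "enc": scale * jax.random.normal(k_enc, (obs_dim, latent_dim)),
        # policy head pi-bar: latent -> action logits
        "pi": scale * jax.random.normal(k_pi, (latent_dim, num_actions)),
        # reward head R-bar: one reward prediction per action
        "r": scale * jax.random.normal(k_r, (latent_dim, num_actions)),
        # transition head P-bar (deterministic mean here): one latent->latent map per action
        "p": scale * jax.random.normal(k_p, (num_actions, latent_dim, latent_dim)),
    }

def forward(params, obs, actions):
    # Every head reads the same latent phi(s), so encoder updates affect all of them
    z = jnp.tanh(obs @ params["enc"])                                   # phi(s): [B, d]
    action_probs = jax.nn.softmax(z @ params["pi"], axis=-1)            # pi(. | s): [B, A]
    all_rewards = z @ params["r"]                                       # [B, A]
    reward_pred = jnp.take_along_axis(all_rewards, actions[:, None], axis=-1)[:, 0]  # [B]
    next_z_pred = jnp.einsum("bd,bde->be", z, params["p"][actions])     # P-bar mean: [B, d]
    return action_probs, reward_pred, next_z_pred

params = init_params(jax.random.PRNGKey(0), obs_dim=8, latent_dim=4, num_actions=3)
probs, r_hat, z_next_hat = forward(params, jnp.ones((5, 8)), jnp.array([0, 1, 2, 0, 1]))
```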

Learning a Reliable World Model

Learn two predictors, $\overline{R}$ (reward) and $\overline{P}$ (transition), on data collected by the behavioral policy $\pi_b$. Both operate on the learned representation $\phi(s)$:

$$ L_R := \mathbb{E}_{(s,a,r)\sim \pi_b}\left|r - \overline{R}(\phi(s), a)\right|, $$
$$ L_P := \mathbb{E}_{(s,a,s')\sim \pi_b}\left[\mathbb{E}_{\overline{s}^+ \sim \overline{P}(\cdot \mid \phi(s), a)}\left\|\phi(s') - \overline{s}^+\right\|\right]. $$

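Both losses measure calibration on the state-action distribution of $\pi_b$, which is exactly where you have data. Combined with the neighborhood constraint below, they tell you whether the model can be trusted for the next policy update.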
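The two losses can be estimated on a batch of behavioral transitions roughly as follows. This sketch makes three assumptions. The predictions have already been computed from $\phi(s)$ and $a$. $\overline{P}$ is sampled $K$ times per transition. $\|\cdot\|$ is the Euclidean norm, which the text leaves unspecified. The function names are hypothetical.

```python
import jax.numpy as jnp

def reward_loss(rewards, pred_rewards):
    # Batch estimate of L_R = E |r - R-bar(phi(s), a)| over transitions collected by pi_b
    return jnp.mean(jnp.abs(rewards - pred_rewards))

def transition_loss(next_latents, sampled_next_latents):
    # next_latents: phi(s') for each transition, shape [B, d]
    # sampled_next_latents: draws s-bar+ ~ P-bar(. | phi(s), a), shape [K, B, d]
    # Euclidean norm is an assumption of this sketch
    dists = jnp.linalg.norm(sampled_next_latents - next_latents[None], axis=-1)  # [K, B]
    # Mean over the K samples (inner expectation) and the batch (outer expectation)
    return jnp.mean(dists)

l_r = reward_loss(jnp.array([1.5, 1.0]), jnp.array([1.0, 2.0]))              # (0.5 + 1.0) / 2
l_p = transition_loss(jnp.array([[0.0, 0.0]]),
                      jnp.array([[[3.0, 4.0]], [[0.0, 0.0]]]))              # (5 + 0) / 2
```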

Enforcing a Policy Neighborhood

Restrict each new policy to stay close to the current one: the probability it assigns to each action must stay within a fixed ratio of the probability under the current policy (similar to PPO’s clipping). This keeps updates close to the data-collection regime, where the world model remains trustworthy.
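One way to check this neighborhood on a batch of visited states is sketched below. It assumes the neighborhood is a symmetric interval $[1-\varepsilon, 1+\varepsilon]$ on the probability ratio, as in PPO; the paper’s exact definition may differ. `in_ratio_neighborhood` and `eps` are hypothetical names. The check only covers the sampled states, not every state.

```python
import jax.numpy as jnp

def in_ratio_neighborhood(new_probs, old_probs, eps=0.2):
    # new_probs: pi(. | s), old_probs: pi_b(. | s), both of shape [B, A]
    # Assumes old_probs > 0 everywhere (e.g. softmax outputs)
    ratio = new_probs / old_probs
    # Assumed neighborhood: 1 - eps <= pi(a|s) / pi_b(a|s) <= 1 + eps for every sampled (s, a)
    return jnp.all(jnp.abs(ratio - 1.0) <= eps)

old = jnp.array([[0.5, 0.5]])
close = in_ratio_neighborhood(jnp.array([[0.55, 0.45]]), old)  # ratios 1.1 and 0.9
far = in_ratio_neighborhood(jnp.array([[0.9, 0.1]]), old)      # ratios 1.8 and 0.2
```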

The Core Guarantee

With constrained updates and small prediction losses, improvements measured in the learned model transfer back to the real environment. The gap between model and reality shrinks as the prediction losses shrink. If you predict rewards and transitions well, the improvements you see in the model carry over to the real environment.

Representation Quality

The encoder remains useful for the next policy: states with very different values can’t collapse to the same latent point without breaking the guarantee. This prevents representation degradation and preserves the learned state abstraction as you improve the policy.

Implementation: PPO + World Models

The practical question: how do you enforce the neighborhood in a modern deep RL pipeline?

The answer is PPO, because its clipped objective already acts like a trust region. We build upon CleanRL implementations for robust JAX-based training. But there is a subtlety: auxiliary losses (like $L_R$ and $L_P$) update the encoder, so they can indirectly push the composed policy $\pi := \bar{\pi} \circ \phi$ outside the neighborhood even if the policy head looks well-behaved.
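The toy example below illustrates this subtlety. It uses a hypothetical tanh encoder and a fixed linear softmax head. An exaggerated encoder change (flipping the sign of one weight) moves the composed policy far outside a PPO-style ratio interval, even though the head parameters never change.

```python
import jax
import jax.numpy as jnp

def composed_policy(enc_params, head_params, obs):
    # pi = pi-bar ∘ phi, with a hypothetical tanh encoder and linear softmax head
    latent = jnp.tanh(obs @ enc_params)
    return jax.nn.softmax(latent @ head_params, axis=-1)

obs = jnp.array([[1.0, 0.0]])
head = jnp.array([[1.0, -1.0], [0.0, 0.0]])   # policy head: identical before and after
enc_before = jnp.eye(2)
enc_after = enc_before.at[0, 0].set(-1.0)      # exaggerated encoder update, e.g. from an auxiliary loss

p_before = composed_policy(enc_before, head, obs)
p_after = composed_policy(enc_after, head, obs)
ratio = p_after / p_before                     # ratio ≈ [[0.22, 4.59]]
```

Neither ratio is anywhere near a typical clipping interval such as $[0.8, 1.2]$. Clipping based on the head alone would therefore not have kept $\pi$ in the neighborhood.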

The fix: fold the auxiliary terms into the advantage function $A(s, a)$:

$$ U(s, a) := A(s, a) - \alpha_R \cdot L_R - \alpha_P \cdot L_P $$

Then optimize a PPO-like objective. This way, the encoder and policy head move together in a way that respects the neighborhood constraint.
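A minimal sketch of the resulting loss follows. It assumes per-sample versions $\ell_R, \ell_P$ of the auxiliary losses and uses the standard PPO clipped surrogate with $U$ in place of $A$. The function name and default coefficients (`alpha_r`, `alpha_p`, `clip_eps`) are illustrative, not the repository’s configuration. The advantages come from the rollout and are constants. `reward_losses` and `transition_losses` are meant to be computed from the current parameters so that their gradients reach the encoder.

```python
import jax.numpy as jnp

def deep_spi_ppo_loss(logp_new, logp_old, advantages, reward_losses, transition_losses,
                      alpha_r=1.0, alpha_p=1.0, clip_eps=0.2):
    # Per-sample augmented advantage: U = A - alpha_R * l_R - alpha_P * l_P
    u = advantages - alpha_r * reward_losses - alpha_p * transition_losses
    # Probability ratio pi(a|s) / pi_old(a|s), computed from log-probabilities
    ratio = jnp.exp(logp_new - logp_old)
    # PPO clipped surrogate with U in place of A: take the pessimistic (min) term
    surrogate = jnp.minimum(ratio * u, jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * u)
    # PPO maximizes the surrogate; return its negation as a loss to minimize
    return -jnp.mean(surrogate)

# Ratio 1: U = [1 - 0.5 - 0, 2 - 0.5 - 1] = [0.5, 0.5], so the loss is -0.5
loss_on_policy = deep_spi_ppo_loss(jnp.zeros(2), jnp.zeros(2), jnp.array([1.0, 2.0]),
                                   jnp.array([0.5, 0.5]), jnp.array([0.0, 1.0]))
# Ratio 2 with U = 1: the clipped term 1.2 * 1 is smaller, so the loss is -1.2
loss_clipped = deep_spi_ppo_loss(jnp.log(jnp.array([2.0])), jnp.zeros(1), jnp.ones(1),
                                 jnp.zeros(1), jnp.zeros(1))
```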

The Data Coverage Problem

One practical detail matters here: PPO is on-policy, so world-model learning can suffer from narrow, correlated coverage if you collect long rollouts from only a few environments. That is exactly the wrong regime for learning a well-calibrated world model.

The solution: use many vectorized JAX environments with shorter rollouts instead of a few environments with long rollouts. This gives much broader state coverage while staying on-policy. That matters because a local guarantee only helps if your local data cover enough of the neighborhood to learn the model well.
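The toy sketch below shows the collection pattern. It assumes a hypothetical 1-D random-walk environment written in JAX, with a uniform random policy standing in for $\pi_b$; this is not the envpool setup the repository uses. `jax.vmap` runs many copies of a short `jax.lax.scan` rollout, so the same transition budget can be spent on many short trajectories from independent start states.

```python
import jax
import jax.numpy as jnp

def toy_step(state, action):
    # Hypothetical 1-D environment: action 1 moves right, action 0 moves left
    next_state = state + jnp.where(action == 1, 1.0, -1.0)
    return next_state, -jnp.abs(next_state)

def rollout(key, init_state, num_steps):
    # Roll out one environment for num_steps with a uniform random policy (stand-in for pi_b)
    def body(state, step_key):
        action = jax.random.bernoulli(step_key).astype(jnp.int32)
        next_state, reward = toy_step(state, action)
        return next_state, (state, action, reward)
    _, traj = jax.lax.scan(body, init_state, jax.random.split(key, num_steps))
    return traj

def collect(key, num_envs, num_steps):
    # Same code, many environments: vmap over independent keys and start states
    init_key, env_key = jax.random.split(key)
    init_states = jax.random.uniform(init_key, (num_envs,), minval=-5.0, maxval=5.0)
    env_keys = jax.random.split(env_key, num_envs)
    return jax.vmap(lambda k, s: rollout(k, s, num_steps))(env_keys, init_states)

key = jax.random.PRNGKey(0)
states_short, actions_short, rewards_short = collect(key, num_envs=128, num_steps=8)
states_long, actions_long, rewards_long = collect(key, num_envs=8, num_steps=128)
```

Both calls spend the same budget of 1024 transitions. The first starts from 128 independent states with 8 steps each, and the second starts from only 8 states with 128 correlated steps each.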

Implementation Details

  • Framework: JAX-based, for efficient jitted computation and GPU scaling
  • Vectorized environments: many parallel JAX-jitted environments with shorter rollouts for better state coverage
  • Reproducibility: checked-in uv.lock for exact dependency reproducibility
  • Benchmarks: ALE-57 (57 Atari games), with results matching or exceeding PPO and DeepMDPs
  • Environments: uses envpool for high-throughput Atari simulation with GPU support
  • Code structure: modular design separating encoder, policy, world model, and PPO training logic

Why It Matters

Most auxiliary-loss approaches (even VAE-MDPs and related work) derive their guarantees on-policy, meaning they are valid only for the policy that collected the data. Deep SPI extends this to show that with controlled updates, the guarantees survive the policy change. This is especially valuable because:

  1. Monotonic improvement: You can reason about whether each update actually helps, not just whether it looks good in hindsight
  2. Representation stability: The encoder doesn’t degrade as you improve the policy—it stays useful for downstream learning
  3. No offline requirement: Deep SPI works in online settings; you don’t need a pre-collected dataset

Results

On ALE-57, Deep SPI matches or exceeds PPO and DeepMDPs while maintaining explicit monotonic improvement bounds. The algorithm demonstrates that principled methods with guarantees and practical performance need not be in tension.

ALE-57 Results: Deep SPI is competitive across multiple metrics. Higher is better for IQM (interquartile mean), median, and mean returns; lower is better for optimality gap.
Sample efficiency: Deep SPI’s wall-clock learning curves match strong baselines, showing the guarantees don’t impose significant computational overhead.

Related