# Source code for rlightning.policy.rsl_rl_policy.ppo_vtrace

"""PPO policy variant using V-trace value estimation.

Overrides rsl_rl's PPO to compute value targets and advantages with V-trace corrections.
"""

from typing import Any, Dict

import torch

from rlightning.policy.utils import vtrace

# Try to import rsl_rl components; if unavailable, define placeholders.
try:
    from rsl_rl.algorithms import PPO
    from rsl_rl.storage.storage import Dataset
    from rsl_rl.utils.benchmarkable import Benchmarkable
    from rsl_rl.utils.recurrency import (
        trajectories_to_transitions,
        transitions_to_trajectories,
    )

    HAS_RSL_RL = True
except ImportError:
    HAS_RSL_RL = False
    PPO = object
    Dataset = Any

if not HAS_RSL_RL:

    class Benchmarkable:
        """No-op stand-in for ``rsl_rl.utils.benchmarkable.Benchmarkable``.

        Defined only when ``rsl_rl`` could not be imported. Defining it
        unconditionally would shadow the real class imported above. It exists so
        that ``@Benchmarkable.register`` decorators in this module still evaluate
        at import time without ``rsl_rl``.
        """

        @staticmethod
        def register(func):
            """Return ``func`` unchanged (no benchmarking without ``rsl_rl``)."""
            return func
class PPOVtrace(PPO):
    """rsl_rl PPO whose value targets and advantages come from V-trace.

    Log-probabilities recorded at collection time (``actions_logp``) are
    compared with log-probabilities of the same actions under the current
    actor. The resulting log importance ratios are passed, together with
    ``rho_bar`` and ``c_bar``, to ``vtrace.vtrace_correction``. That function
    returns the value targets and advantages stored in the dataset.

    Args:
        env: Environment, forwarded to ``PPO.__init__``.
        rho_bar: Forwarded to ``vtrace.vtrace_correction``. Presumably the V-trace
            importance-weight truncation level (rho-bar); confirm against ``vtrace``.
        c_bar: Forwarded to ``vtrace.vtrace_correction``. Presumably the V-trace
            trace-coefficient truncation level (c-bar); confirm against ``vtrace``.
        **kwargs: Forwarded to ``PPO.__init__``.

    Raises:
        ImportError: If ``rsl_rl`` is not installed.
    """

    def __init__(
        self,
        env,
        rho_bar: float = 1.0,
        c_bar: float = 1.0,
        **kwargs,
    ):
        # Fail early with a clear message: without rsl_rl, ``PPO`` is ``object``.
        if not HAS_RSL_RL:
            raise ImportError("PPOVtrace requires 'rsl_rl' to be installed. " "Please install it to use this policy.")
        super().__init__(env, **kwargs)
        self._rho_bar = rho_bar
        self._c_bar = c_bar
        # Presumably adds these attributes to the agent's serialized state (rsl_rl API).
        self._register_serializable("_rho_bar", "_c_bar")

    @Benchmarkable.register
    def _process_dataset(self, dataset: Any) -> Any:
        """Override PPO dataset processing to apply V-trace value estimation.

        Each entry of ``dataset`` is read as a dict of tensors. The keys used are
        ``rewards``, ``dones``, ``values``, ``actions_logp``, ``actions``,
        ``actor_observations``, and optionally ``next_values`` and ``timeouts``.
        The recurrent path also uses ``actor_state_h``/``actor_state_c``. When
        ``next_values`` is absent, it uses ``next_critic_observations`` and
        ``critic_state_h``/``critic_state_c`` of the last entry.

        Args:
            dataset: Sequence of per-rollout-step dicts of tensors.

        Returns:
            The same ``dataset``. Each entry gains ``target_value``,
            ``advantages`` and ``normalized_advantages``. The latter are
            normalized by the mean/std over all steps and environments.
        """
        # The outputs are stored and later used as constants in the losses, so
        # they must not carry autograd history from the actor/critic forwards.
        with torch.no_grad():
            rewards = torch.stack([entry["rewards"] for entry in dataset])
            dones = torch.stack([entry["dones"] for entry in dataset])
            values = torch.stack([entry["values"] for entry in dataset])
            next_values = self._vtrace_next_values(dataset, values)

            behavior_logp = torch.stack([entry["actions_logp"] for entry in dataset])
            log_rhos = self._vtrace_current_logp(dataset, dones) - behavior_logp

            if "timeouts" in dataset[0]:
                timeouts = torch.stack([entry["timeouts"] for entry in dataset])
                # Add gamma * V(s_t) to the reward on timed-out steps.
                rewards += self.gamma * timeouts * values

            vs, advantages = vtrace.vtrace_correction(
                rewards, values, next_values, dones, log_rhos, self.gamma, self._rho_bar, self._c_bar
            )

            amean = advantages.mean()
            # std() of a single element is NaN; treat that as zero spread.
            astd = torch.nan_to_num(advantages.std())

        for entry, target, adv in zip(dataset, vs, advantages):
            entry["target_value"] = target
            entry["advantages"] = adv
            entry["normalized_advantages"] = (adv - amean) / (astd + 1e-8)
        return dataset

    def _vtrace_next_values(self, dataset: Any, values: torch.Tensor) -> torch.Tensor:
        """Return the successor-state value estimate for every rollout step."""
        # Entries are dicts: test for the key (``hasattr`` never sees dict keys).
        if "next_values" in dataset[0]:
            return torch.stack([entry["next_values"] for entry in dataset])

        # Otherwise shift ``values`` by one step. The last step is filled by
        # evaluating the critic on the final successor observation.
        last = dataset[-1]
        critic_kwargs = {"hidden_state": (last["critic_state_h"], last["critic_state_c"])} if self.recurrent else {}
        final_values = self.critic.forward(last["next_critic_observations"], **critic_kwargs)
        return torch.cat((values[1:], final_values.unsqueeze(0)), dim=0)

    def _vtrace_current_logp(self, dataset: Any, dones: torch.Tensor) -> torch.Tensor:
        """Return log-probs of the stored actions under the current actor, shaped like the first two dataset dims."""
        actions = torch.stack([entry["actions"] for entry in dataset])
        actor_obs = torch.stack([entry["actor_observations"] for entry in dataset])

        if self.recurrent:
            transition_obs = actor_obs.reshape(*actor_obs.shape[:2], -1)
            actor_state_h = torch.stack([entry["actor_state_h"] for entry in dataset])
            actor_state_c = torch.stack([entry["actor_state_c"] for entry in dataset])
            observations, data = transitions_to_trajectories(transition_obs, dones)
            hidden_state_h, _ = transitions_to_trajectories(actor_state_h, dones)
            hidden_state_c, _ = transitions_to_trajectories(actor_state_c, dones)
            # Initialize each trajectory with its first stored hidden state.
            hidden_state = (hidden_state_h[0].transpose(0, 1), hidden_state_c[0].transpose(0, 1))
            action_mean, action_std = self.actor.forward(observations, hidden_state=hidden_state, compute_std=True)
            action_mean = action_mean.reshape(*observations.shape[:-1], self._action_size)
            action_std = action_std.reshape(*observations.shape[:-1], self._action_size)
            action_mean = trajectories_to_transitions(action_mean, data)
            action_std = trajectories_to_transitions(action_std, data)
        else:
            action_mean, action_std = self.actor.forward(actor_obs.flatten(0, 1), compute_std=True)

        # Bring both branches to (steps * envs, action_dim) before building the
        # distribution. The recurrent output is unpacked per time step and would
        # not otherwise line up with the flattened actions.
        action_dim = actions.shape[-1]
        dist = torch.distributions.Normal(action_mean.reshape(-1, action_dim), action_std.reshape(-1, action_dim))
        logp = dist.log_prob(actions.reshape(-1, action_dim)).sum(-1)
        return logp.reshape(actor_obs.shape[:2])

    @Benchmarkable.register
    def _compute_value_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Clipped squared-error loss between the critic and the V-trace targets.

        The regression target is ``batch["target_value"]``, written by
        ``_process_dataset``. The clipped prediction limits the critic's change
        relative to ``batch["values"]`` to ``self._clip_ratio``.

        Args:
            batch: Minibatch dict with ``critic_observations``, ``values`` and
                ``target_value``, plus ``dones`` and ``critic_state_h``/``critic_state_c``
                when recurrent.

        Returns:
            Scalar loss: mean of the element-wise max of clipped and unclipped squared errors.
        """
        if self.recurrent:
            observations, data = transitions_to_trajectories(batch["critic_observations"], batch["dones"])
            hidden_state_h, _ = transitions_to_trajectories(batch["critic_state_h"], batch["dones"])
            hidden_state_c, _ = transitions_to_trajectories(batch["critic_state_c"], batch["dones"])
            # Initialize each trajectory with its first stored hidden state.
            hidden_states = (hidden_state_h[0].transpose(0, 1), hidden_state_c[0].transpose(0, 1))
            trajectory_evaluations = self.critic.forward(observations, hidden_state=hidden_states)
            trajectory_evaluations = trajectory_evaluations.reshape(*observations.shape[:-1])
            evaluation = trajectories_to_transitions(trajectory_evaluations, data)
        else:
            evaluation = self.critic.forward(batch["critic_observations"])

        old_values = batch["values"]
        returns = batch["target_value"]
        value_clipped = old_values + (evaluation - old_values).clamp(-self._clip_ratio, self._clip_ratio)
        value_losses = (evaluation - returns).pow(2)
        value_losses_clipped = (value_clipped - returns).pow(2)
        return torch.max(value_losses, value_losses_clipped).mean()