Source code for rlightning.buffer.utils.utils

"""Default preprocessing utilities for buffer transitions."""

from typing import Any, Callable, Dict, List, Optional, Tuple

import torch

from rlightning.types import EnvRet, PolicyResponse, Processed_EnvRet_fields

from .preprocessors import (
    Preprocessor,
    default_obs_preprocessor,
    default_reward_preprocessor,
)


[docs] def default_env_ret_preprocess_fn( transition_dict: Dict[str, Any], env_ret: EnvRet, obs_preprocessor: Preprocessor, reward_preprocessor: Preprocessor, ) -> Dict[str, Any]: """Populate a transition dict from an EnvRet.""" if not isinstance(env_ret, EnvRet): raise TypeError(f"env_ret must be an instance of EnvRet, got {type(env_ret)}") env_ret_dict = env_ret.to_dict() # preprocess env_ret_dict["observation"] = obs_preprocessor(env_ret.observation) env_ret_dict["last_reward"] = reward_preprocessor(env_ret.last_reward) for key, value in env_ret_dict.items(): transition_dict[key] = value return transition_dict
[docs] def default_policy_resp_preprocess_fn( transition_dict: Dict[str, Any], policy_resp: PolicyResponse, ) -> Dict[str, Any]: """Populate a transition dict from a PolicyResponse.""" if not isinstance(policy_resp, PolicyResponse): raise TypeError(f"policy_resp must be an instance of PolicyResponse, got {type(policy_resp)}") policy_resp_dict = policy_resp.to_dict() for key, value in policy_resp_dict.items(): transition_dict[key] = value return transition_dict
[docs] def default_preprocess_fn( transition_dict: Dict[str, Any], env_ret: Optional[EnvRet] = None, policy_resp: Optional[PolicyResponse] = None, obs_preprocessor: Optional[Preprocessor] = default_obs_preprocessor, reward_preprocessor: Optional[Preprocessor] = default_reward_preprocessor, env_ret_preprocess_fn: Optional[Callable] = default_env_ret_preprocess_fn, policy_resp_preprocess_fn: Optional[Callable] = default_policy_resp_preprocess_fn, ) -> Dict[str, Any]: """ Default transition preprocess function. It will use the given obs_preprocessor and reward_preprocessor to preprocess both env_ret and policy_resp, or either one of them. When adding transition in a sync manner (env_ret and policy_resp are paired in one step of rollout), both env_ret and policy_resp should be provided. When adding transition in an async manner, only one of them should be provided. Args: transition_dict: The dict to aggregate transition data from env_ret and policy_resp. env_ret (Optional[EnvRet]): The environment return to be preprocessed. policy_resp (Optional[PolicyResponse]): The policy response to be preprocessed. obs_preprocessor (Optional[Preprocessor]): The preprocessor for observations. reward_preprocessor (Optional[Preprocessor]): The preprocessor for rewards. env_ret_preprocess_fn (Optional[Callable]): Function to preprocess `env_ret`. policy_resp_preprocess_fn (Optional[Callable]): Function to preprocess `policy_resp`. Returns: The preprocessed transition dict. """ if env_ret is None and policy_resp is None: raise ValueError("At least one of env_ret or policy_resp must be provided.") if env_ret is not None and policy_resp is not None: if env_ret.env_id != policy_resp.env_id: raise ValueError( f"Mismatched env_id in env_ret and policy_resp, got {env_ret.env_id} and " f"{policy_resp.env_id}" ) if env_ret is not None: transition_dict = env_ret_preprocess_fn(transition_dict, env_ret, obs_preprocessor, reward_preprocessor) if policy_resp is not None: transition_dict = policy_resp_preprocess_fn(transition_dict, policy_resp) return transition_dict
[docs] def default_postprocess_fn(episode_buffer: Dict[str, List[Any]]) -> Dict[str, Any]: """Convert an episode buffer into a flat training-ready dict.""" data = {} for k, v in episode_buffer.items(): # skip info by default if "info" in k: continue # process keys with "last_" prefix if isinstance(v[0], torch.Tensor): v = torch.stack(v, dim=0) else: v = torch.tensor(v) if k.startswith("last_"): _k = k[5:] _v = v[1:] # support for both 1D and 2D arrays (vector env) data[_k] = _v else: _k, _v = k, v # process special keys if k == "observation": data["next_observation"] = v[1:] data["observation"] = v[:-1] # process policy_resp keys env_fields = set(EnvRet.fields() + Processed_EnvRet_fields) policy_fields = [k for k in episode_buffer.keys() if k not in env_fields] if k in policy_fields: data[k] = v[:-1] return data
[docs] def default_compute_gae( rewards: torch.Tensor, values: torch.Tensor, next_values: torch.Tensor, dones: torch.Tensor, gamma: float, lam: float, normalize_adv: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor]: """Compute Generalized Advantage Estimation (GAE). This version uses an explicit loop over the batch dimension. Args: rewards (torch.Tensor): Rewards at each timestep values (torch.Tensor): Value function estimates at each timestep next_values (torch.Tensor): Value function estimates at the next timestep dones (torch.Tensor): Done flags at each timestep gamma (float): Discount factor lam (float): GAE lambda parameter normalize_adv (bool): Whether to normalize the advantages Returns: Tuple[torch.Tensor, torch.Tensor]: Computed advantages and returns """ B, N = rewards.shape # B=batch_size, N=num_envs device = rewards.device advantages = torch.zeros_like(rewards, device=device) last_advantage = torch.zeros(N, device=device) # [num_envs,] each env has independent last_adv # iterate backwards over batch dimension for t in reversed(range(B)): # when done = True, (1 - dones[t])=0, cutting off the accumulation chain. delta = rewards[t] + gamma * next_values[t] * (1.0 - dones[t]) - values[t] # compute advantages[t] for all envs in batch, 2nd term is 0 if done=True advantages[t] = delta + gamma * lam * (1.0 - dones[t]) * last_advantage # update last_advantage for next step in reverse order last_advantage = advantages[t] returns = advantages + values if normalize_adv: advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) return advantages, returns
[docs] def default_gae_no_loop( rewards: torch.Tensor, values: torch.Tensor, next_values: torch.Tensor, dones: torch.Tensor, gamma: float, lam: float, normalize_adv: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor]: """Compute Generalized Advantage Estimation (GAE). This version uses matrix operations to eliminate explicit loops for efficiency. Args: rewards (torch.Tensor): Rewards at each timestep values (torch.Tensor): Value function estimates at each timestep next_values (torch.Tensor): Value function estimates at the next timestep dones (torch.Tensor): Done flags at each timestep gamma (float): Discount factor lam (float): GAE lambda parameter normalize_adv (bool): Whether to normalize the advantages Returns: Tuple[torch.Tensor, torch.Tensor]: Computed advantages and returns """ B, N = rewards.shape # B=batch_size, N=num_envs device = rewards.device coeff = gamma * lam deltas = rewards + gamma * next_values * (1.0 - dones.float()) - values # shape [B, N] discount = coeff * (1.0 - dones) # [B, N],done=True indicates discount=0 # generate reversed cumulative product mask: cumprod from back to front, # done=True makes all subsequent values 0. discount_mask = torch.cat([torch.ones(1, N, device=device), discount[:-1]], dim=0) discount_cum = torch.cumprod(discount_mask.flip(0), dim=0).flip(0) # [B, N] # generate weight matrix and compute advantages weight = coeff ** torch.arange(B, device=device).view(B, 1) # [B, 1] weight_matrix = weight.unsqueeze(1) * discount_cum # [B, N] advantages = torch.matmul(deltas.T, weight_matrix).T # [B, N] returns = advantages + values if normalize_adv: advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) return advantages, returns