# AReaL/arealite/utils/functional.py

from typing import Dict, Optional, Tuple
import numpy as np
import torch
import torch.distributed as dist
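

# Compiled helper: per-token log-probability of `labels` under the
# temperature-scaled softmax of `logits`.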
@torch.compile
def _gather_logprobs(
    logits: torch.Tensor, labels: torch.Tensor, temperature: float = 1.0
):
    log_probs = torch.nn.functional.log_softmax(logits.float() / temperature, dim=-1)
    log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
    return log_probs_labels
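

# Compiled helper: same label gather as above, additionally returning the
# per-position entropy of the temperature-scaled distribution.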
@torch.compile
def _gather_logprobs_entropy(
    logits: torch.Tensor, labels: torch.Tensor, temperature: float = 1.0
):
    log_probs = torch.nn.functional.log_softmax(logits.float() / temperature, dim=-1)
    entropy = -torch.sum(log_probs.exp() * log_probs, dim=-1)
    log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
    return log_probs_labels, entropy
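

# Chunked wrapper around `_gather_logprobs`: slices the leading (token) dimension
# into pieces of `chunk_size` so the full-vocabulary softmax is never materialized
# for the whole batch at once.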
def gather_logprobs(
    logits: torch.Tensor,
    labels: torch.Tensor,
    temperature: float = 1.0,
    chunk_size: int = 1024,
):
    batch_size = logits.shape[0]
    if batch_size <= chunk_size:
        return _gather_logprobs(logits, labels, temperature)
    log_probs_labels_list = []
    for i in range(0, batch_size, chunk_size):
        end_idx = min(i + chunk_size, batch_size)
        chunk_logits = logits[i:end_idx]
        chunk_labels = labels[i:end_idx]
        chunk_log_probs = _gather_logprobs(chunk_logits, chunk_labels, temperature)
        log_probs_labels_list.append(chunk_log_probs)
    return torch.cat(log_probs_labels_list)
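

# Chunked wrapper around `_gather_logprobs_entropy`; returns the concatenated
# per-token log-probabilities and entropies.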
def gather_logprobs_entropy(
    logits: torch.Tensor,
    labels: torch.Tensor,
    temperature: float = 1.0,
    chunk_size: int = 1024,
):
    batch_size = logits.shape[0]
    if batch_size <= chunk_size:
        return _gather_logprobs_entropy(logits, labels, temperature)
    log_probs_labels_list = []
    entropy_list = []
    for i in range(0, batch_size, chunk_size):
        end_idx = min(i + chunk_size, batch_size)
        chunk_logits = logits[i:end_idx]
        chunk_labels = labels[i:end_idx]
        chunk_log_probs, chunk_entropy = _gather_logprobs_entropy(
            chunk_logits, chunk_labels, temperature
        )
        log_probs_labels_list.append(chunk_log_probs)
        entropy_list.append(chunk_entropy)
    return torch.cat(log_probs_labels_list), torch.cat(entropy_list)
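

# Masked whitening: computes mean and variance over the entries selected by
# `mask` (optionally all-reduced across ranks) and returns
# `(x - mean) / (var.sqrt() + eps)`.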
@torch.no_grad()
def masked_normalization(
    x: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    dim=None,
    unbiased=False,
    eps=1e-5,
    high_precision=True,
    all_reduce=True,
    reduce_group=None,
):
    dtype = torch.float64 if high_precision else torch.float32
    x = x.to(dtype)
    if dim is None:
        dim = tuple(range(len(x.shape)))
    if mask is None:
        factor = torch.tensor(
            np.prod([x.shape[d] for d in dim]), dtype=dtype, device=x.device
        )
    else:
        mask = mask.to(dtype)
        x = x * mask
        factor = mask.sum(dim, keepdim=True)
    x_sum = x.sum(dim=dim, keepdim=True)
    x_sum_sq = x.square().sum(dim=dim, keepdim=True)
    if dist.is_initialized() and all_reduce:
        dist.all_reduce(factor, op=dist.ReduceOp.SUM, group=reduce_group)
        dist.all_reduce(x_sum, op=dist.ReduceOp.SUM, group=reduce_group)
        dist.all_reduce(
            x_sum_sq,
            op=dist.ReduceOp.SUM,
            group=reduce_group,
        )
    mean = x_sum / factor
    meansq = x_sum_sq / factor
    var = meansq - mean**2
    if unbiased:
        var *= factor / (factor - 1)
    return ((x - mean) / (var.sqrt() + eps)).float()
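

# PPO actor loss: clipped surrogate on the proximal policy ratio, optional dual
# ("c") clipping, and an importance weight between the proximal and behavior
# policies (optionally capped) for the decoupled-loss setting.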
def ppo_actor_loss_fn(
    logprobs: torch.Tensor,
    proximal_logprobs: torch.Tensor,
    old_logprobs: torch.Tensor,
    advantages: torch.Tensor,
    eps_clip: float,
    loss_mask: torch.Tensor,
    c_clip: Optional[float] = None,
    behav_imp_weight_cap: Optional[float] = None,
) -> Tuple[torch.Tensor, Dict]:
    """Clipped PPO policy-gradient loss with optional dual clipping and a
    behavioral importance-weight cap.

    When the decoupled loss is disabled:
    1. if log-probabilities are recomputed, both ``old_logprobs`` and
       ``proximal_logprobs`` hold the recomputed values;
    2. without recomputation, both ``old_logprobs`` and ``proximal_logprobs``
       are produced by the inference backend.
    When the decoupled loss is enabled, ``proximal_logprobs`` holds the
    recomputed log-probabilities and ``old_logprobs`` is produced by the
    inference engine.
    """
    loss_mask_count = loss_mask.count_nonzero() or 1
    ratio = torch.where(loss_mask, torch.exp(logprobs - proximal_logprobs), 0)
    clipped_ratio = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip)
    pg_loss1 = -advantages * ratio
    pg_loss2 = -advantages * clipped_ratio
    clip_mask = pg_loss1.detach() < pg_loss2.detach()
    pg_loss = torch.max(pg_loss1, pg_loss2)
    if c_clip is not None:
        assert c_clip > 1.0, c_clip
        pg_loss3 = torch.sign(advantages) * c_clip * advantages
        dual_clip_mask = pg_loss3.detach() < pg_loss.detach()
        pg_loss = torch.min(pg_loss, pg_loss3)
    else:
        dual_clip_mask = torch.zeros_like(clip_mask)
    behav_kl = proximal_logprobs - old_logprobs
    behav_imp_weight = behav_kl.exp()
    behav_mask = (
        (behav_imp_weight <= behav_imp_weight_cap).logical_and(loss_mask)
        if behav_imp_weight_cap is not None
        else loss_mask
    )
    behav_kl = torch.where(behav_mask, behav_kl, 0.0)
    behav_imp_weight = torch.where(behav_mask, behav_imp_weight, 0.0)
    pg_loss = pg_loss * behav_imp_weight
    logging_loss = pg_loss.detach()
    pg_loss = torch.where(loss_mask, pg_loss, 0).sum() / loss_mask_count
    clip_mask.logical_and_(loss_mask)
    dual_clip_mask.logical_and_(loss_mask)
    stat = dict(
        loss=logging_loss,
        importance_weight=ratio.detach(),
        approx_kl=(logprobs - proximal_logprobs).detach(),
        clip_mask=clip_mask,
        dual_clip_mask=dual_clip_mask,
    )
    if proximal_logprobs is not None:
        stat["behave_imp_weight"] = behav_imp_weight
        stat["behave_approx_kl"] = behav_kl
        stat["behave_mask"] = behav_mask
    return pg_loss, stat
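

# Minimal usage sketch (illustrative only, not part of the original module):
# exercises the helpers above with random data. Shapes, seeds, and
# hyperparameters below are assumptions chosen for the example.
if __name__ == "__main__":
    torch.manual_seed(0)
    n_tokens, vocab = 8, 32
    logits = torch.randn(n_tokens, vocab)
    labels = torch.randint(0, vocab, (n_tokens,))
    logp, entropy = gather_logprobs_entropy(logits, labels, temperature=1.0)

    # Whiten dummy advantages; skip cross-rank reduction in this single-process sketch.
    advantages = masked_normalization(torch.randn(n_tokens), all_reduce=False)
    loss_mask = torch.ones(n_tokens, dtype=torch.bool)
    loss, stat = ppo_actor_loss_fn(
        logprobs=logp,
        proximal_logprobs=logp.detach(),
        old_logprobs=logp.detach(),
        advantages=advantages,
        eps_clip=0.2,
        loss_mask=loss_mask,
    )
    print(f"loss={loss.item():.4f}, mean_entropy={entropy.mean().item():.4f}")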