# AReaL/realhf/base/stats_tracker.py
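"""Hierarchical, distributed-aware statistics tracking.

`DistributedStatsTracker` records tensor statistics under slash-separated
scope keys, masks them with boolean "denominator" tensors, and aggregates
them (avg/sum/min/max/scalar) across a torch.distributed process group on
`export()`. A module-level default tracker is exposed through the `stat`,
`denominator`, `scalar`, `scope`, and `export` aliases at the bottom of
this file.
"""
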
from collections import defaultdict
from enum import Enum, auto
from typing import Dict

import torch
import torch.distributed as dist


class ReduceType(Enum):
    AVG = auto()
    SUM = auto()
    MIN = auto()
    MAX = auto()
    SCALAR = auto()


# Filled in externally with per-layer MoE auxiliary loss tensors; consumed
# and cleared on every `DistributedStatsTracker.export()` call.
MOE_AUX_LOSSES = {}


class DistributedStatsTracker:
    def __init__(self, name: str = ""):
        self.scope_stack = []
        if name:
            self.scope_stack.append(name.strip("/"))
        self.denominators = {}  # key -> denominator key
        self.reduce_types = {}  # key -> ReduceType
        self.stats = defaultdict(list)

    def scope(self, name):
        """Context manager for hierarchical scoping."""
        return self.Scope(self, name)

    class Scope:
        def __init__(self, tracker, name):
            self.tracker = tracker
            self.name = name.strip("/")

        def __enter__(self):
            self.tracker.scope_stack.append(self.name)
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            self.tracker.scope_stack.pop()
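
    # Illustrative scoping example (the names here are hypothetical):
    #
    #     tracker = DistributedStatsTracker("actor")
    #     with tracker.scope("ppo"):
    #         tracker.scalar(lr=1e-5)  # recorded under the key "actor/ppo/lr"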

    def _get_full_key(self, key):
        """Combine the scope stack with the current key."""
        if not self.scope_stack:
            return key
        return "/".join(self.scope_stack + [key])

    def denominator(self, **kwargs):
        for key, value in kwargs.items():
            if not isinstance(value, torch.Tensor) or value.dtype != torch.bool:
                raise ValueError(
                    f"`{key}` must be a torch.bool tensor, "
                    f"got: {getattr(value, 'dtype', type(value))}"
                )
            if value.numel() == 0:
                raise ValueError(f"`{key}` must be non-empty")
            full_key = self._get_full_key(key)
            self._set_reduce_type(full_key, ReduceType.SUM)
            self.stats[full_key].append(value.detach().clone())
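
    # Illustrative: a denominator is a boolean mask; its True entries define
    # which elements of later `stat()` tensors are counted, e.g.
    #
    #     tracker.denominator(n_valid_tokens=mask)  # `mask`: hypothetical bool tensor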

    def scalar(self, **kwargs):
        for key, value in kwargs.items():
            full_key = self._get_full_key(key)
            self._set_reduce_type(full_key, ReduceType.SCALAR)
            self.stats[full_key].append(float(value))

    def stat(
        self,
        denominator: str,
        reduce_type: ReduceType | None = None,
        **kwargs,
    ):
        """Record multiple named tensors, masked by a previously recorded denominator."""
        for key, value in kwargs.items():
            if not isinstance(value, torch.Tensor) or value.dtype != torch.float:
                raise ValueError(
                    f"`{key}` must be a torch.float tensor, "
                    f"got: {getattr(value, 'dtype', type(value))}"
                )
            if value.numel() == 0:
                raise ValueError(f"`{key}` must be non-empty")
            if reduce_type == ReduceType.SCALAR:
                raise ValueError("Cannot use the SCALAR reduce type for a tensor")
            full_key = self._get_full_key(key)
            denorm = self._get_full_key(denominator)
            if denorm not in self.stats or not self.stats[denorm]:
                raise ValueError(f"Denominator `{denorm}` does not exist")
            # Every recorded value (including the new one) must match the
            # denominator's shape so masking lines up element-wise.
            for x, y in zip(self.stats[denorm], self.stats[full_key] + [value]):
                assert x.shape == y.shape, (x.shape, y.shape)
            self.denominators[full_key] = denorm
            if reduce_type is not None:
                self._set_reduce_type(full_key, reduce_type)
            self.stats[full_key].append(value.detach().clone())
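
    # Illustrative end-to-end recording (tensor values are made up):
    #
    #     mask = torch.tensor([True, True, False])
    #     tracker.denominator(n_valid=mask)
    #     tracker.stat(denominator="n_valid", loss=torch.rand(3))
    #     tracker.export()  # includes "loss/avg", "loss/min", "loss/max"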

    def _set_reduce_type(self, key, reduce_type):
        if not isinstance(reduce_type, ReduceType):
            raise ValueError("reduce_type must be a ReduceType enum")
        self.reduce_types[key] = reduce_type

    def export(self, key=None, reduce_group=None, reset=True) -> Dict[str, float]:
        """Aggregate recorded statistics, optionally across a process group.

        Note: when a reduce group is used this is a collective call; every
        rank in the group must call `export()` with matching keys.
        """
        self._amend_moe_losses()
        if reduce_group is None:
            try:
                from realhf.base.constants import data_parallel_group

                reduce_group = data_parallel_group()
            except Exception:
                # Fall back to local aggregation, e.g. in single-process runs.
                pass
        if key is not None:
            full_key = self._get_full_key(key)
            result = self._aggregate(full_key, reduce_group)
            if reset:
                self.denominators.pop(full_key, None)
                self.reduce_types.pop(full_key, None)
                self.stats.pop(full_key, None)
            return {
                k: v.cpu().item() if torch.is_tensor(v) else v
                for k, v in result.items()
            }
        results = {}
        for key in list(self.stats.keys()):
            results.update(self._aggregate(key, reduce_group))
        if reset:
            self.denominators = {}
            self.reduce_types = {}
            self.stats = defaultdict(list)
        return {
            k: v.cpu().item() if torch.is_tensor(v) else v for k, v in results.items()
        }

    def _amend_moe_losses(self):
        from realhf.base.constants import is_last_pipe_stage, pipe_parallel_group

        global MOE_AUX_LOSSES
        mean_losses = {}
        for k, loss in MOE_AUX_LOSSES.items():
            # Sum partial losses across pipeline stages (default ReduceOp.SUM).
            dist.all_reduce(loss, group=pipe_parallel_group())
            mean_losses[k] = float(loss.mean())  # average over layers
        MOE_AUX_LOSSES.clear()
        if mean_losses and is_last_pipe_stage():
            self.scalar(**mean_losses)
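
    # Illustrative producer side (the key name is hypothetical): model code
    # stores a per-layer loss tensor, which `export()` later reduces and logs:
    #
    #     from realhf.base import stats_tracker
    #     stats_tracker.MOE_AUX_LOSSES["load_balance"] = per_layer_losses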

    def _aggregate(self, key, reduce_group):
        if key not in self.stats or not self.stats[key]:
            return {}
        reduce_type = self.reduce_types.get(key, None)
        result = {}
        if reduce_type is None:
            # No explicit reduce type: report avg/min/max under suffixed keys.
            result["/".join([key, "avg"])] = self._avg_of(key, reduce_group)
            result["/".join([key, "min"])] = self._min_of(key, reduce_group)
            result["/".join([key, "max"])] = self._max_of(key, reduce_group)
        elif reduce_type == ReduceType.AVG:
            result[key] = self._avg_of(key, reduce_group)
        elif reduce_type == ReduceType.SUM:
            result[key] = self._sum_of(key, reduce_group)
        elif reduce_type == ReduceType.MIN:
            result[key] = self._min_of(key, reduce_group)
        elif reduce_type == ReduceType.MAX:
            result[key] = self._max_of(key, reduce_group)
        elif reduce_type == ReduceType.SCALAR:
            result[key] = sum(self.stats[key]) / len(self.stats[key])
        else:
            raise ValueError(f"Unknown reduce type: {reduce_type}")
        # Drop entries whose denominator was empty everywhere (value is None).
        return {k: v for k, v in result.items() if v is not None}

    def _sum_of(self, key, reduce_group):
        values = self.stats[key]
        if key not in self.denominators:
            x = sum([x.sum() for x in values])
            if reduce_group is not None:
                dist.all_reduce(x, group=reduce_group)
        else:
            denominator = self.denominators[key]
            if denominator not in self.stats:
                raise ValueError(
                    f"Denominator `{denominator}` not set for key `{key}`."
                )
            xs = []
            for v, d in zip(values, self.stats[denominator]):
                xs.append(torch.where(d, v, 0.0).sum())
            x = sum(xs)
            if reduce_group is not None:
                dist.all_reduce(x, group=reduce_group)
        return float(x)

    def _avg_of(self, key, reduce_group):
        values = self.stats[key]
        denominator = self.denominators[key]
        if denominator not in self.stats:
            raise ValueError(f"Denominator `{denominator}` not set for key `{key}`.")
        xs = []
        ds = []
        for v, d in zip(values, self.stats[denominator]):
            xs.append(torch.where(d, v, 0.0).sum())
            ds.append(d.sum())
        x = sum(xs)
        d = sum(ds)
        if reduce_group is not None:
            dist.all_reduce(x, group=reduce_group)
            dist.all_reduce(d, group=reduce_group)
        if d == 0:
            return None
        return x / d

    def _min_of(self, key, reduce_group):
        values = self.stats[key]
        denominator = self.denominators[key]
        if denominator not in self.stats:
            raise ValueError(f"Denominator `{denominator}` not set for key `{key}`.")
        xs = []
        for v, d in zip(values, self.stats[denominator]):
            xs.append(torch.where(d, v, float("inf")).min())
        x = min(xs)
        if reduce_group is not None:
            dist.all_reduce(x, group=reduce_group, op=dist.ReduceOp.MIN)
        if torch.isinf(x):
            return None
        return float(x)

    def _max_of(self, key, reduce_group):
        values = self.stats[key]
        denominator = self.denominators[key]
        if denominator not in self.stats:
            raise ValueError(f"Denominator `{denominator}` not set for key `{key}`.")
        xs = []
        for v, d in zip(values, self.stats[denominator]):
            xs.append(torch.where(d, v, -float("inf")).max())
        x = max(xs)
        if reduce_group is not None:
            dist.all_reduce(x, group=reduce_group, op=dist.ReduceOp.MAX)
        if torch.isinf(x):
            return None
        return float(x)


# Module-level default tracker and convenience aliases.
DEFAULT_TRACKER = DistributedStatsTracker()
stat = DEFAULT_TRACKER.stat
denominator = DEFAULT_TRACKER.denominator
export = DEFAULT_TRACKER.export
scope = DEFAULT_TRACKER.scope
scalar = DEFAULT_TRACKER.scalar
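

# Minimal usage sketch (illustrative, not part of the library API). Assumes
# the `realhf` package is importable (`export()` touches realhf.base.constants)
# and a single-process run, where `export()` falls back to local aggregation
# when `data_parallel_group()` is unavailable.
if __name__ == "__main__":
    tracker = DistributedStatsTracker("train")
    mask = torch.tensor([True, True, False, True])
    with tracker.scope("ppo"):
        tracker.denominator(n_valid=mask)
        tracker.stat(denominator="n_valid", loss=torch.rand(4))
        tracker.scalar(lr=1e-5)
    # Unset reduce types export avg/min/max; denominators export their sum.
    print(tracker.export())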