normalize loss scale by tokens

meizhiyu.mzy 2025-02-27 11:36:26 +08:00
parent 241185227d
commit 2436ce519e
7 changed files with 80 additions and 3 deletions

View File

@ -27,6 +27,10 @@ from realhf.base.recover import StepInfo
logger = logging.getLogger("model_api")
class ZeroTotalLossWeightException(Exception):
pass
@dataclasses.dataclass
class GenerationHyperparameters:
"""Generation hyperparameters.
@ -334,6 +338,8 @@ class PipelinableEngine(abc.ABC):
input_: SequenceSample,
mb_spec: MicroBatchSpec,
loss_fn: Callable[[torch.Tensor, SequenceSample], Tuple[torch.Tensor, Dict]],
loss_weight_fn: Callable,
token_normalize_scope: Literal["global", "dp"],
version_steps: int,
) -> Tuple[torch.Tensor, Dict] | None:
"""Update the model with a batch of data and a loss function.
@ -345,6 +351,10 @@ class PipelinableEngine(abc.ABC):
:param loss_fn: The loss function. It takes the output of the forward pass and the
input data, returning the loss and a dictionary of statistics.
:type loss_fn: Callable[[torch.Tensor, SequenceSample], Tuple[torch.Tensor, Dict]]
:param loss_weight_fn: A function mapping a micro batch to its loss weight
(e.g., the number of tokens that contribute to the loss), used to normalize
the loss across micro batches.
:type loss_weight_fn: Callable
:param token_normalize_scope: The scope of token-wise loss normalization. Choices:
global: average across all micro batches across DP ranks.
dp: average across micro batches in the current DP rank.
:type token_normalize_scope: Literal["global", "dp"]
:param version_steps: The global step counter for this experiment,
used by the backend to determine the learning rate schedule.
:type version_steps: int
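
For orientation, here is a minimal sketch (not part of this patch; all names below are illustrative) of how the two scopes turn per-micro-batch token counts into loss scales:

```python
def loss_scales(mb_token_counts, dp_rank_totals, scope="global"):
    """Compute per-micro-batch loss scales for one DP rank.

    mb_token_counts: token counts of this rank's micro batches.
    dp_rank_totals: total token counts of every DP rank (this rank included).
    """
    if scope == "dp":
        # Average over the tokens seen by this DP rank only.
        local_total = sum(mb_token_counts)
        return [w / local_total for w in mb_token_counts]
    # "global": average over the tokens of all DP ranks. The factor
    # len(dp_rank_totals) (the DP world size) cancels the division by
    # world size that happens later when DP gradients are averaged.
    global_total = sum(dp_rank_totals)
    return [w / global_total * len(dp_rank_totals) for w in mb_token_counts]


# Two DP ranks; this rank holds micro batches of 30 and 70 tokens.
print(loss_scales([30, 70], dp_rank_totals=[100, 150], scope="global"))  # [0.24, 0.56]
print(loss_scales([30, 70], dp_rank_totals=[100, 150], scope="dp"))      # [0.3, 0.7]
```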

View File

@ -745,6 +745,8 @@ class ReaLMegatronEngine(model_api.PipelinableEngine):
input_: SequenceSample,
mb_spec: MicroBatchSpec,
loss_fn: Callable,
loss_weight_fn: Callable,
token_normalize_scope: str,
version_steps: int,
):
with megatron_ctx():
@ -765,10 +767,24 @@ class ReaLMegatronEngine(model_api.PipelinableEngine):
input_=input_,
mb_spec=mb_spec,
loss_fn=loss_fn,
loss_weight_fn=loss_weight_fn,
token_normalize_scope=token_normalize_scope,
version_steps=version_steps,
)
mb_inputs = input_.divide_into_mbs_balanced(mb_spec)
total_loss_weight = torch.tensor(
sum([loss_weight_fn(mb) for mb in mb_inputs]), dtype=torch.float32
)
if token_normalize_scope == "global":
# dist.all_reduce reduces in place; its return value is not the reduced
# tensor, so it must not be assigned back to total_loss_weight.
dist.all_reduce(
total_loss_weight, group=constants.data_parallel_group()
)
if total_loss_weight == 0:
raise model_api.ZeroTotalLossWeightException(
"The sum of loss weights of all micro batches is zero."
)
if constants.parallelism_rank() == 0:
logger.info(
f"MB spec: {mb_spec}, #mbs={len(mb_inputs)}, "
@ -795,6 +811,10 @@ class ReaLMegatronEngine(model_api.PipelinableEngine):
max_seqlen=max_seqlen,
).logits
loss, _stat = loss_fn(model_output, mb_input)
loss_scale = loss_weight_fn(mb_inputs[i]) / total_loss_weight
if token_normalize_scope == "global":
loss_scale *= constants.data_parallel_world_size()
loss *= loss_scale
with cuda_tmarked("bwd", CUDATimeMarkType.backward):
self.engine.optim.scale_loss(loss).backward()
for k, v in _stat.items():
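
A quick numerical check of the scaling applied above, sketched under the assumption that DP gradient synchronization averages gradients over ranks (which is what the world-size factor in the "global" scope compensates for):

```python
# Each inner list stands for the per-token losses of one DP rank.
per_rank_token_losses = [[1.0] * 100, [1.0] * 150]
world_size = len(per_rank_token_losses)
global_tokens = sum(len(toks) for toks in per_rank_token_losses)

# Per-rank loss after applying loss_scale = rank_tokens / global_tokens * world_size.
scaled = [sum(toks) / global_tokens * world_size for toks in per_rank_token_losses]

# The DP gradient all-reduce averages over ranks, undoing the world_size factor.
dp_averaged = sum(scaled) / world_size
assert abs(dp_averaged - 1.0) < 1e-6  # equals the true per-token mean loss
```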

View File

@ -107,6 +107,8 @@ class MockTrainEngine(model_api.PipelinableEngine):
input_: SequenceSample,
mb_spec: MicroBatchSpec,
loss_fn: Callable,
loss_weight_fn: Callable,
token_normalize_scope: str,
version_steps: int,
):
self.optimizer.zero_grad()
@ -120,10 +122,23 @@ class MockTrainEngine(model_api.PipelinableEngine):
input_=input_,
mb_spec=mb_spec,
loss_fn=loss_fn,
loss_weight_fn=loss_weight_fn,
token_normalize_scope=token_normalize_scope,
version_steps=version_steps,
)
mb_inputs = input_.divide_into_mbs_balanced(mb_spec)
total_loss_weight = torch.tensor(
sum([loss_weight_fn(mb) for mb in mb_inputs]), dtype=torch.float32
)
if token_normalize_scope == "global":
# dist.all_reduce reduces in place; do not assign its return value.
dist.all_reduce(
total_loss_weight, group=constants.data_parallel_group()
)
if total_loss_weight == 0:
raise model_api.ZeroTotalLossWeightException(
"The sum of loss weights of all micro batches is zero."
)
if constants.parallelism_rank() == 0:
logger.info(
@ -147,6 +162,10 @@ class MockTrainEngine(model_api.PipelinableEngine):
max_seqlen=max_seqlen,
).logits
loss, _stat = loss_fn(model_output, mb_input)
loss_scale = loss_weight_fn(mb_inputs[i]) / total_loss_weight
if token_normalize_scope == "global":
loss_scale *= constants.data_parallel_world_size()
loss *= loss_scale
for k, v in _stat.items():
stat[k] += v

View File

@ -17,7 +17,10 @@ import realhf.impl.model.parallelism.pipeline_parallel.p2p as p2p
import realhf.impl.model.parallelism.pipeline_parallel.static_schedule as schedule
import realhf.impl.model.utils.cuda_graph as cuda_graph
from realhf.api.core.data_api import MicroBatchSpec, SequenceSample
-from realhf.api.core.model_api import GenerationHyperparameters
+from realhf.api.core.model_api import (
+GenerationHyperparameters,
+ZeroTotalLossWeightException,
+)
from realhf.base.datapack import flat2d
from realhf.impl.model.nn.real_llm_api import ReaLModel
from realhf.impl.model.nn.real_llm_base import PipeCacheData, PipeTransferData
@ -668,7 +671,7 @@ class PipeTrainForwardCommInstrSet:
"input_cache", micro_batch_id, remove=True
)
loss, stats = loss_fn(model_output, input_cache)
-loss = loss / tensor_buffer.get("n_pp_mbs", micro_batch_id)
+loss = loss * tensor_buffer.get("loss_scale", micro_batch_id)
tensor_buffer.put("losses", micro_batch_id, loss)
tensor_buffer.put("stats", micro_batch_id, stats)
@ -992,6 +995,8 @@ class PipelineRunner:
input_: SequenceSample,
mb_spec: MicroBatchSpec,
loss_fn: Callable,
loss_weight_fn: Callable,
token_normalize_scope: str,
version_steps: int,
):
# TODO: return whether update success
@ -1004,6 +1009,14 @@ class PipelineRunner:
mb_spec, n_mbs=mb_spec.n_mbs * self.default_train_mbs
)
mb_inputs = input_.divide_into_mbs_balanced(mb_spec)
total_loss_weight = torch.tensor(
sum([loss_weight_fn(mb) for mb in mb_inputs]), dtype=torch.float32
)
if token_normalize_scope == "global":
# dist.all_reduce reduces in place; do not assign its return value.
dist.all_reduce(
total_loss_weight, group=constants.data_parallel_group()
)
if constants.parallelism_rank() == 0:
logger.info(
f"MB spec: {mb_spec}, #mbs={len(mb_inputs)}, "
@ -1016,6 +1029,10 @@ class PipelineRunner:
tensor_buffer = TensorBuffer()
for i in range(n_pp_mbs):
tensor_buffer.put("n_pp_mbs", i, n_pp_mbs)
loss_scale = loss_weight_fn(mb_inputs[i]) / total_loss_weight
if token_normalize_scope == "global":
loss_scale *= constants.data_parallel_world_size()
tensor_buffer.put("loss_scale", i, loss_scale)
tensor_buffer.put("version_steps", i, version_steps)
tensor_buffer.put("loss_fn", i, loss_fn)

View File

@ -7,7 +7,7 @@ import dataclasses
import functools
import itertools
import time
-from typing import Dict, Optional, Tuple
+from typing import Dict, Literal, Optional, Tuple
import torch
import torch.distributed as dist
@ -186,6 +186,7 @@ class PPOActorInterface(model_api.ModelInterface):
mask_too_long: bool = False
use_dense_reward: bool = False
reward_delta: bool = True
token_normalize_scope: Literal["global", "dp"] = "global"
def __post_init__(self):
if self.adaptive_kl_ctl:
@ -672,6 +673,7 @@ class PPOActorInterface(model_api.ModelInterface):
# Run mini-batched PPO training!
train_stats = collections.defaultdict(lambda: 0)
for data in datas:
stats = module.train_batch(
input_=data,
@ -685,6 +687,8 @@ class PPOActorInterface(model_api.ModelInterface):
early_stop_kl=self.early_stop_kl,
temperature=self.gconfig.temperature,
),
loss_weight_fn=lambda x: x.data["ppo_loss_mask"].count_nonzero(),
token_normalize_scope=self.token_normalize_scope,
)
if stats:
@ -888,6 +892,7 @@ class PPOCriticInterface(model_api.ModelInterface):
mask_too_long: bool = False
use_dense_reward: bool = False
reward_delta: bool = True
token_normalize_scope: Literal["global", "dp"] = "global"
def __post_init__(self):
if self.adaptive_kl_ctl:
@ -1111,6 +1116,8 @@ class PPOCriticInterface(model_api.ModelInterface):
kl_adapter=self.kl_adapter,
rms=None if not self.value_norm else self.rms,
),
loss_weight_fn=lambda x: x.data["ppo_loss_mask"].count_nonzero(),
token_normalize_scope=self.token_normalize_scope,
)
if stats:
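
For both the actor and the critic, the micro-batch weight is the number of tokens that actually enter the PPO loss. A tiny sketch of what that lambda computes, assuming ppo_loss_mask marks response tokens and zeroes out prompt and padding positions:

```python
import torch

# Prompt and padding positions are masked out; only response tokens count.
ppo_loss_mask = torch.tensor([0, 0, 1, 1, 1, 0], dtype=torch.bool)
loss_weight = ppo_loss_mask.count_nonzero()
print(loss_weight.item())  # 3 tokens contribute to this micro batch's loss
```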

View File

@ -100,6 +100,8 @@ class SFTInterface(model_api.ModelInterface):
stat = module.train_batch(
input_=data,
loss_fn=compute_packed_sft_loss,
loss_weight_fn=lambda _: 1,  # called with the micro batch; constant weight per mb
token_normalize_scope="dp",
mb_spec=mb_spec,
version_steps=model.version.global_step,
)
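
With a constant weight per micro batch and scope "dp", the normalization degenerates to uniform averaging over this rank's micro batches, so SFT training is unchanged. A one-line sanity check:

```python
n_mbs = 4
weights = [1] * n_mbs                         # the weight fn returns 1 for every mb
scales = [w / sum(weights) for w in weights]
assert scales == [0.25] * n_mbs               # identical to dividing by n_mbs
```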

View File

@ -496,6 +496,8 @@ def _test_para_realloc(
if not is_critic
else compute_critic_loss
),
loss_weight_fn=lambda _: 1,  # called with the micro batch; constant weight per mb
token_normalize_scope="dp",
version_steps=i,
)