mirror of https://github.com/inclusionAI/AReaL
normalize loss scale by tokens
This commit is contained in:
parent 241185227d
commit 2436ce519e
@@ -27,6 +27,10 @@ from realhf.base.recover import StepInfo
logger = logging.getLogger("model_api")


class ZeroTotalLossWeightException(Exception):
    pass


@dataclasses.dataclass
class GenerationHyperparameters:
    """Generation hyperparameters.
@@ -334,6 +338,8 @@ class PipelinableEngine(abc.ABC):
        input_: SequenceSample,
        mb_spec: MicroBatchSpec,
        loss_fn: Callable[[torch.Tensor, SequenceSample], Tuple[torch.Tensor, Dict]],
        loss_weight_fn: Callable,
        token_normalize_scope: Literal["global", "dp"],
        version_steps: int,
    ) -> Tuple[torch.Tensor, Dict] | None:
        """Update the model with a batch of data and a loss function.
@@ -345,6 +351,10 @@ class PipelinableEngine(abc.ABC):
        :param loss_fn: The loss function. It takes the output of the forward pass and the
            input data, returning the loss and a dictionary of statistics.
        :type loss_fn: Callable[[torch.Tensor, SequenceSample], Tuple[torch.Tensor, Dict]]
        :param token_normalize_scope: The scope of token-wise loss normalization. Choices:
            "global": average across all micro batches across DP ranks;
            "dp": average across micro batches in the current DP rank.
        :type token_normalize_scope: Literal["global", "dp"]
        :param version_steps: The global step counter for this experiment,
            used by the backend to determine the learning rate schedule.
        :type version_steps: int
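
The two scopes only differ in which token total the per-micro-batch losses are divided by. A minimal standalone sketch of that arithmetic (scale_per_micro_batch and its arguments are illustrative names, not part of the API; the token counts are made up):

def scale_per_micro_batch(weights, scope, dp_world_size=1, global_weight_sum=None):
    # weights: per-micro-batch loss weights on this DP rank, e.g. valid-token counts.
    # "dp" divides by the local weight sum; "global" divides by the all-reduced sum
    # across DP ranks and multiplies by dp_world_size so that the rank-averaging
    # gradient reduction still yields a true global token average.
    local_sum = sum(weights)
    if scope == "dp":
        return [w / local_sum for w in weights]
    return [w * dp_world_size / global_weight_sum for w in weights]

# This rank holds micro batches with 30 and 70 valid tokens; another rank holds 300,
# so the all-reduced global weight would be 400.
print(scale_per_micro_batch([30, 70], "dp"))                  # [0.3, 0.7]
print(scale_per_micro_batch([30, 70], "global", 2, 400))      # [0.15, 0.35]
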
@@ -745,6 +745,8 @@ class ReaLMegatronEngine(model_api.PipelinableEngine):
        input_: SequenceSample,
        mb_spec: MicroBatchSpec,
        loss_fn: Callable,
        loss_weight_fn: Callable,
        token_normalize_scope: str,
        version_steps: int,
    ):
        with megatron_ctx():
@@ -765,10 +767,24 @@ class ReaLMegatronEngine(model_api.PipelinableEngine):
                input_=input_,
                mb_spec=mb_spec,
                loss_fn=loss_fn,
                loss_weight_fn=loss_weight_fn,
                token_normalize_scope=token_normalize_scope,
                version_steps=version_steps,
            )

            mb_inputs = input_.divide_into_mbs_balanced(mb_spec)
            total_loss_weight = torch.tensor(
                sum([loss_weight_fn(mb) for mb in mb_inputs]), dtype=torch.float32
            )
            if token_normalize_scope == "global":
                dist.all_reduce(
                    total_loss_weight, group=constants.data_parallel_group()
                )
            if total_loss_weight == 0:
                raise model_api.ZeroTotalLossWeightException(
                    "The sum of loss weights of all micro batches is zero."
                )

            if constants.parallelism_rank() == 0:
                logger.info(
                    f"MB spec: {mb_spec}, #mbs={len(mb_inputs)}, "
@@ -795,6 +811,10 @@ class ReaLMegatronEngine(model_api.PipelinableEngine):
                        max_seqlen=max_seqlen,
                    ).logits
                    loss, _stat = loss_fn(model_output, mb_input)
                    loss_scale = loss_weight_fn(mb_inputs[i]) / total_loss_weight
                    if token_normalize_scope == "global":
                        loss_scale *= constants.data_parallel_world_size()
                    loss *= loss_scale
                    with cuda_tmarked("bwd", CUDATimeMarkType.backward):
                        self.engine.optim.scale_loss(loss).backward()
                    for k, v in _stat.items():
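
Why the extra data_parallel_world_size factor under "global": the DP gradient reduction is assumed to average over ranks (that is what the world-size factor compensates for), so scaling each micro batch's loss by weight * world_size / global_weight and then averaging across ranks recovers a plain token-weighted mean over all ranks. A small numeric check of that reasoning, mimicking the rank average on loss values instead of gradients (all numbers made up):

dp_world_size = 2
rank_losses = [[1.0, 3.0], [2.0]]      # per-micro-batch mean losses on each rank (made up)
rank_weights = [[30, 70], [300]]       # per-micro-batch valid-token counts (made up)

global_weight = sum(sum(w) for w in rank_weights)   # what the all-reduce would produce

per_rank = [
    sum(l * (w * dp_world_size / global_weight) for l, w in zip(losses, weights))
    for losses, weights in zip(rank_losses, rank_weights)
]
dp_averaged = sum(per_rank) / dp_world_size          # stand-in for the rank-averaging reduction

expected = sum(
    l * w
    for losses, weights in zip(rank_losses, rank_weights)
    for l, w in zip(losses, weights)
) / global_weight                                    # token-weighted mean over everything

assert abs(dp_averaged - expected) < 1e-9            # both are 2.1 here
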
@@ -107,6 +107,8 @@ class MockTrainEngine(model_api.PipelinableEngine):
        input_: SequenceSample,
        mb_spec: MicroBatchSpec,
        loss_fn: Callable,
        loss_weight_fn: Callable,
        token_normalize_scope: str,
        version_steps: int,
    ):
        self.optimizer.zero_grad()
@@ -120,10 +122,23 @@ class MockTrainEngine(model_api.PipelinableEngine):
                input_=input_,
                mb_spec=mb_spec,
                loss_fn=loss_fn,
                loss_weight_fn=loss_weight_fn,
                token_normalize_scope=token_normalize_scope,
                version_steps=version_steps,
            )

        mb_inputs = input_.divide_into_mbs_balanced(mb_spec)
        total_loss_weight = torch.tensor(
            sum([loss_weight_fn(mb) for mb in mb_inputs]), dtype=torch.float32
        )
        if token_normalize_scope == "global":
            dist.all_reduce(
                total_loss_weight, group=constants.data_parallel_group()
            )
        if total_loss_weight == 0:
            raise model_api.ZeroTotalLossWeightException(
                "The sum of loss weights of all micro batches is zero."
            )

        if constants.parallelism_rank() == 0:
            logger.info(
@@ -147,6 +162,10 @@ class MockTrainEngine(model_api.PipelinableEngine):
                max_seqlen=max_seqlen,
            ).logits
            loss, _stat = loss_fn(model_output, mb_input)
            loss_scale = loss_weight_fn(mb_inputs[i]) / total_loss_weight
            if token_normalize_scope == "global":
                loss_scale *= constants.data_parallel_world_size()
            loss *= loss_scale
            for k, v in _stat.items():
                stat[k] += v
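
A quick sanity property of these scales, sketched with made-up token counts: under "dp" the scales on one rank always sum to 1, while under "global" they sum to dp_world_size * local_tokens / global_tokens, which averages back to 1 once the rank-averaging gradient reduction divides by the world size.

dp_world_size = 4
local_tokens = [30, 70]        # this rank's per-micro-batch token counts (made up)
global_tokens = 1000           # all-reduced total across ranks (made up)

dp_scales = [w / sum(local_tokens) for w in local_tokens]
global_scales = [w * dp_world_size / global_tokens for w in local_tokens]

assert abs(sum(dp_scales) - 1.0) < 1e-9
assert abs(sum(global_scales) - dp_world_size * sum(local_tokens) / global_tokens) < 1e-9
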
@@ -17,7 +17,10 @@ import realhf.impl.model.parallelism.pipeline_parallel.p2p as p2p
import realhf.impl.model.parallelism.pipeline_parallel.static_schedule as schedule
import realhf.impl.model.utils.cuda_graph as cuda_graph
from realhf.api.core.data_api import MicroBatchSpec, SequenceSample
from realhf.api.core.model_api import GenerationHyperparameters
from realhf.api.core.model_api import (
    GenerationHyperparameters,
    ZeroTotalLossWeightException,
)
from realhf.base.datapack import flat2d
from realhf.impl.model.nn.real_llm_api import ReaLModel
from realhf.impl.model.nn.real_llm_base import PipeCacheData, PipeTransferData
@@ -668,7 +671,7 @@ class PipeTrainForwardCommInstrSet:
            "input_cache", micro_batch_id, remove=True
        )
        loss, stats = loss_fn(model_output, input_cache)
        loss = loss / tensor_buffer.get("n_pp_mbs", micro_batch_id)
        loss = loss * tensor_buffer.get("loss_scale", micro_batch_id)
        tensor_buffer.put("losses", micro_batch_id, loss)
        tensor_buffer.put("stats", micro_batch_id, stats)
@@ -992,6 +995,8 @@ class PipelineRunner:
        input_: SequenceSample,
        mb_spec: MicroBatchSpec,
        loss_fn: Callable,
        loss_weight_fn: Callable,
        token_normalize_scope: str,
        version_steps: int,
    ):
        # TODO: return whether update success
@@ -1004,6 +1009,14 @@ class PipelineRunner:
            mb_spec, n_mbs=mb_spec.n_mbs * self.default_train_mbs
        )
        mb_inputs = input_.divide_into_mbs_balanced(mb_spec)
        total_loss_weight = torch.tensor(
            sum([loss_weight_fn(mb) for mb in mb_inputs]), dtype=torch.float32
        )
        if token_normalize_scope == "global":
            dist.all_reduce(
                total_loss_weight, group=constants.data_parallel_group()
            )

        if constants.parallelism_rank() == 0:
            logger.info(
                f"MB spec: {mb_spec}, #mbs={len(mb_inputs)}, "
@@ -1016,6 +1029,10 @@ class PipelineRunner:
        tensor_buffer = TensorBuffer()
        for i in range(n_pp_mbs):
            tensor_buffer.put("n_pp_mbs", i, n_pp_mbs)
            loss_scale = loss_weight_fn(mb_inputs[i]) / total_loss_weight
            if token_normalize_scope == "global":
                loss_scale *= constants.data_parallel_world_size()
            tensor_buffer.put("loss_scale", i, loss_scale)
            tensor_buffer.put("version_steps", i, version_steps)
            tensor_buffer.put("loss_fn", i, loss_fn)
@@ -7,7 +7,7 @@ import dataclasses
import functools
import itertools
import time
from typing import Dict, Optional, Tuple
from typing import Dict, Literal, Optional, Tuple

import torch
import torch.distributed as dist
@@ -186,6 +186,7 @@ class PPOActorInterface(model_api.ModelInterface):
    mask_too_long: bool = False
    use_dense_reward: bool = False
    reward_delta: bool = True
    token_normalize_scope: Literal["global", "dp"] = "global"

    def __post_init__(self):
        if self.adaptive_kl_ctl:
@@ -672,6 +673,7 @@ class PPOActorInterface(model_api.ModelInterface):

        # Run mini-batched PPO training!
        train_stats = collections.defaultdict(lambda: 0)

        for data in datas:
            stats = module.train_batch(
                input_=data,
@@ -685,6 +687,8 @@ class PPOActorInterface(model_api.ModelInterface):
                    early_stop_kl=self.early_stop_kl,
                    temperature=self.gconfig.temperature,
                ),
                loss_weight_fn=lambda x: x.data["ppo_loss_mask"].count_nonzero(),
                token_normalize_scope=self.token_normalize_scope,
            )

            if stats:
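
For PPO, a micro batch's loss weight is the number of tokens that actually enter the PPO loss, i.e. the non-zero entries of its ppo_loss_mask. A standalone illustration of that counting (the mask below is fabricated; count_nonzero is the same tensor method used above):

import torch

# 1 where a token contributes to the PPO loss, 0 for prompt/padding positions.
ppo_loss_mask = torch.tensor([1, 1, 1, 0, 0, 1, 1, 0], dtype=torch.bool)
loss_weight = ppo_loss_mask.count_nonzero()    # tensor(5)

# Under token_normalize_scope="global", micro batches (and DP ranks) holding more
# unmasked tokens therefore receive proportionally larger loss scales.
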
@@ -888,6 +892,7 @@ class PPOCriticInterface(model_api.ModelInterface):
    mask_too_long: bool = False
    use_dense_reward: bool = False
    reward_delta: bool = True
    token_normalize_scope: Literal["global", "dp"] = "global"

    def __post_init__(self):
        if self.adaptive_kl_ctl:
@@ -1111,6 +1116,8 @@ class PPOCriticInterface(model_api.ModelInterface):
                    kl_adapter=self.kl_adapter,
                    rms=None if not self.value_norm else self.rms,
                ),
                loss_weight_fn=lambda x: x.data["ppo_loss_mask"].count_nonzero(),
                token_normalize_scope=self.token_normalize_scope,
            )

            if stats:
@@ -100,6 +100,8 @@ class SFTInterface(model_api.ModelInterface):
        stat = module.train_batch(
            input_=data,
            loss_fn=compute_packed_sft_loss,
            loss_weight_fn=lambda x: 1,
            token_normalize_scope="dp",
            mb_spec=mb_spec,
            version_steps=model.version.global_step,
        )
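
SFT weights every micro batch equally (the weight function returns a constant 1 rather than a token count) and normalizes only within the current DP rank, so each micro batch's loss is simply divided by the number of micro batches on that rank. A tiny sketch of what the constant weight reduces to:

n_mbs = 3
weights = [1] * n_mbs                        # loss_weight_fn returns 1 per micro batch
scales = [w / sum(weights) for w in weights]
assert scales == [1 / n_mbs] * n_mbs         # plain 1/n averaging within the DP rank
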
@@ -496,6 +496,8 @@ def _test_para_realloc(
            if not is_critic
            else compute_critic_loss
        ),
        loss_weight_fn=lambda x: 1,
        token_normalize_scope="dp",
        version_steps=i,
    )