meizhiyu.mzy 2025-02-27 15:40:21 +08:00
parent 2436ce519e
commit 35f50e3f53
5 changed files with 30 additions and 17 deletions

View File

@@ -338,9 +338,9 @@ class PipelinableEngine(abc.ABC):
input_: SequenceSample,
mb_spec: MicroBatchSpec,
loss_fn: Callable[[torch.Tensor, SequenceSample], Tuple[torch.Tensor, Dict]],
loss_weight_fn: Callable,
token_normalize_scope: Literal["global", "dp"],
loss_weight_fn: Callable[[SequenceSample], torch.Tensor],
version_steps: int,
token_normalize_scope: Literal["global", "dp"] = "global",
) -> Tuple[torch.Tensor, Dict] | None:
"""Update the model with a batch of data and a loss function.
@@ -351,10 +351,9 @@ class PipelinableEngine(abc.ABC):
:param loss_fn: The loss function. It takes the output of the forward pass and the
input data, returning the loss and a dictionary of statistics.
:type loss_fn: Callable[[torch.Tensor, SequenceSample], Tuple[torch.Tensor, Dict]]
:param global_normalize_scope: The scope of token-wise loss normalization. Choices:
global: average across all micro batches across DP ranks.
dp: average across micro batches in current DP rank.
:type global_normalize_scope: Literal["global", "dp"]
:param loss_weight_fn: The function that computes the weight of a micro batch
(e.g., its number of trainable tokens), used to normalize the loss across micro batches.
:type loss_weight_fn: Callable[[SequenceSample], torch.Tensor]
:param version_steps: The global step counter for this experiment,
used by the backend to determine the learning rate schedule.
:type version_steps: int
@@ -364,6 +363,11 @@ class PipelinableEngine(abc.ABC):
which automatically schedules the forward and backward passes. For non-pipelined
training, forward and backward passes are executed iteratively over mini-batches
to accumulate gradients. If None, the batch will not be split.
:param token_normalize_scope: The scope of token-wise loss normalization. Choices:
"global": average across all micro batches across DP ranks;
"dp": average across micro batches in the current DP rank.
Defaults to "global".
:type token_normalize_scope: Literal["global", "dp"]
"""
raise NotImplementedError()
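
For reference, here is a minimal, self-contained sketch of how a caller might use the updated train_batch signature. The stub engine, batch object, and loss functions are illustrative stand-ins rather than the real SequenceSample/MicroBatchSpec classes; only the keyword names and the callback shapes mirror the interface above.

# Self-contained sketch of the updated call signature; everything here is a toy stand-in.
from types import SimpleNamespace
from typing import Callable, Dict, Literal, Optional, Tuple
import torch

class StubEngine:
    def train_batch(
        self,
        input_,
        mb_spec,
        loss_fn: Callable[..., Tuple[torch.Tensor, Dict]],
        loss_weight_fn: Callable[..., torch.Tensor],
        version_steps: int,
        token_normalize_scope: Literal["global", "dp"] = "global",
    ) -> Optional[Tuple[torch.Tensor, Dict]]:
        # A real engine splits `input_` into micro batches per `mb_spec`, runs
        # forward/backward on each, and normalizes the loss by the summed
        # weights (optionally all-reduced across DP ranks).
        output = torch.randn(4, 8)
        return loss_fn(output, input_)

def toy_loss_fn(output: torch.Tensor, inp) -> Tuple[torch.Tensor, Dict]:
    loss = output.float().mean()
    return loss, {"loss": float(loss)}

def toy_loss_weight_fn(mb) -> torch.Tensor:
    # Weight a micro batch by how many tokens it carries (made-up field).
    return torch.tensor(float(mb.n_tokens))

loss, stat = StubEngine().train_batch(
    input_=SimpleNamespace(n_tokens=512),
    mb_spec=None,                       # placeholder for a MicroBatchSpec
    loss_fn=toy_loss_fn,
    loss_weight_fn=toy_loss_weight_fn,
    version_steps=0,                    # global step, drives the LR schedule
    token_normalize_scope="global",     # or "dp" to normalize within one rank
)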

View File

@@ -664,7 +664,6 @@ class PipeTrainInstrSetForMegatron(PipeTrainInstrSet):
loss: torch.Tensor = tensor_buffer.get(
"losses", micro_batch_id, remove=True
)
loss = self.engine.optim.scale_loss(loss)
loss.backward()
tensor_buffer.put("losses", micro_batch_id, loss.detach().clone())
return
@@ -777,7 +776,7 @@ class ReaLMegatronEngine(model_api.PipelinableEngine):
sum([loss_weight_fn(mb) for mb in mb_inputs]), dtype=torch.float32
)
if token_normalize_scope == "global":
total_loss_weight = dist.all_reduce(
dist.all_reduce(
total_loss_weight, group=constants.data_parallel_group()
)
if total_loss_weight == 0:
@@ -813,10 +812,15 @@ class ReaLMegatronEngine(model_api.PipelinableEngine):
loss, _stat = loss_fn(model_output, mb_input)
loss_scale = loss_weight_fn(mb_inputs[i]) / total_loss_weight
if token_normalize_scope == "global":
# Megatron will average gradients across DP ranks.
# If we normalize loss across micro batches of all DP ranks,
# we should revert the effect of gradient averaging in megatron
# to make sure loss from each token is scaled properly.
loss_scale *= constants.data_parallel_world_size()
loss_scale *= self.engine.optim.get_loss_scale()
loss *= loss_scale
with cuda_tmarked("bwd", CUDATimeMarkType.backward):
self.engine.optim.scale_loss(loss).backward()
loss.backward()
for k, v in _stat.items():
stat[k] += v
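
The comment above motivates the data_parallel_world_size() factor. Below is a self-contained numeric sketch of that reasoning, with per-token losses standing in for gradients (backward is linear, so the same factor carries through) and made-up token counts per rank.

# With token_normalize_scope == "global", each micro batch's loss is divided by
# the token count summed over *all* DP ranks, yet Megatron still *averages*
# gradients across ranks. Multiplying the scale by the DP world size cancels
# that extra 1/world_size factor so every token contributes equally.
import torch

dp_world_size = 2
rank_losses = [torch.tensor([2.0, 4.0]), torch.tensor([6.0])]  # per-token losses per rank (made up)
total_tokens = sum(len(l) for l in rank_losses)                # global token count = 3

# Target: the plain average over every token on every rank.
target = torch.cat(rank_losses).sum() / total_tokens           # (2 + 4 + 6) / 3 = 4.0

# Globally normalized per-rank losses, then DP averaging (what Megatron does):
per_rank = [l.sum() / total_tokens for l in rank_losses]
dp_averaged = sum(per_rank) / dp_world_size                    # 2.0 -- too small by 1/world_size

# Reverting the averaging by folding world_size into the loss scale:
corrected = sum(s * dp_world_size for s in per_rank) / dp_world_size
assert torch.isclose(corrected, target)                        # back to 4.0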

View File

@@ -132,9 +132,7 @@ class MockTrainEngine(model_api.PipelinableEngine):
sum([loss_weight_fn(mb) for mb in mb_inputs]), dtype=torch.float32
)
if token_normalize_scope == "global":
total_loss_weight = dist.all_reduce(
total_loss_weight, group=constants.data_parallel_group()
)
dist.all_reduce(total_loss_weight, group=constants.data_parallel_group())
if total_loss_weight == 0:
raise model_api.ZeroTotalLossWeightException(
"The sum of loss weights of all micro batches is zero."

View File

@@ -754,6 +754,7 @@ class PipeTrainForwardCommInstrSet:
@dataclasses.dataclass
class PipeTrainInstrSet:
engine: Any
def _exec_optimizer_step(self, *args, **kwargs):
raise NotImplementedError()
@@ -1013,9 +1014,7 @@ class PipelineRunner:
sum([loss_weight_fn(mb) for mb in mb_inputs]), dtype=torch.float32
)
if token_normalize_scope == "global":
total_loss_weight = dist.all_reduce(
total_loss_weight, group=constants.data_parallel_group()
)
dist.all_reduce(total_loss_weight, group=constants.data_parallel_group())
if constants.parallelism_rank() == 0:
logger.info(
@@ -1031,7 +1030,12 @@
tensor_buffer.put("n_pp_mbs", i, n_pp_mbs)
loss_scale = loss_weight_fn(mb_inputs[i]) / total_loss_weight
if token_normalize_scope == "global":
# Megatron will average gradients across DP ranks.
# If we normalize loss across micro batches of all DP ranks,
# we should revert the effect of gradient averaging in megatron
# to make sure loss from each token is scaled properly.
loss_scale *= constants.data_parallel_world_size()
loss_scale *= instr_set.engine.optim.get_loss_scale()
tensor_buffer.put("loss_scale", i, loss_scale)
tensor_buffer.put("version_steps", i, version_steps)
tensor_buffer.put("loss_fn", i, loss_fn)

View File

@@ -89,6 +89,7 @@ def compute_packed_sft_loss(
class SFTInterface(model_api.ModelInterface):
token_normalize_scope: Literal["global", "dp"] = "global"
def train_step(
self, model: model_api.Model, data: SequenceSample, mb_spec: MicroBatchSpec
@@ -100,8 +101,10 @@ class SFTInterface(model_api.ModelInterface):
stat = module.train_batch(
input_=data,
loss_fn=compute_packed_sft_loss,
loss_weight_fn=lambda: 1,
token_normalize_scope="dp",
loss_weight_fn=lambda x: x.data["prompt_mask"]
.logical_not()
.count_nonzero(),
token_normalize_scope=self.token_normalize_scope,
mb_spec=mb_spec,
version_steps=model.version.global_step,
)
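
As a sanity check of what the new loss_weight_fn computes, the toy snippet below counts the non-prompt (answer) tokens of a micro batch. The SimpleNamespace merely stands in for a SequenceSample; only the "prompt_mask" field is assumed here.

# Each micro batch is weighted by the number of tokens that actually contribute to the loss.
from types import SimpleNamespace
import torch

mb = SimpleNamespace(
    data={"prompt_mask": torch.tensor([1, 1, 1, 0, 0, 0, 0], dtype=torch.bool)}
)
loss_weight_fn = lambda x: x.data["prompt_mask"].logical_not().count_nonzero()
print(loss_weight_fn(mb))  # tensor(4): four answer tokens contribute to the loss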