mirror of https://github.com/inclusionAI/AReaL
parent 2436ce519e
commit 35f50e3f53

@@ -338,9 +338,9 @@ class PipelinableEngine(abc.ABC):
input_: SequenceSample,
mb_spec: MicroBatchSpec,
loss_fn: Callable[[torch.Tensor, SequenceSample], Tuple[torch.Tensor, Dict]],
loss_weight_fn: Callable,
token_normalize_scope: Literal["global", "dp"],
loss_weight_fn: Callable[[torch.Tensor, SequenceSample], torch.Tensor],
version_steps: int,
token_normalize_scope: Literal["global", "dp"] = "global",
) -> Tuple[torch.Tensor, Dict] | None:
"""Update the model with a batch of data and a loss function.

@@ -351,10 +351,9 @@ class PipelinableEngine(abc.ABC):
:param loss_fn: The loss function. It takes the output of the forward pass and the
input data, returning the loss and a dictionary of statistics.
:type loss_fn: Callable[[torch.Tensor, SequenceSample], Tuple[torch.Tensor, Dict]]
:param global_normalize_scope: The scope of token-wise loss normalization. Choices:
global: average across all micro batches across DP ranks.
dp: average across micro batches in current DP rank.
:type global_normalize_scope: Literal["global", "dp"]
:param loss_weight_fn: This function is used to calculate weight when normalizing
loss across micro batches.
:type loss_weight_fn: Callable[[torch.Tensor, SequenceSample], torch.Tensor]
:param version_steps: The global step counter for this experiment,
used by the backend to determine the learning rate schedule.
:type version_steps: int

@@ -364,6 +363,11 @@ class PipelinableEngine(abc.ABC):
which automatically schedules the forward and backward passes. For non-pipelined
training, forward and backward passes are executed iteratively over mini-batches
to accumulate gradients. If None, the batch will not be split.
:param global_normalize_scope: The scope of token-wise loss normalization. Choices:
global: average across all micro batches across DP ranks.
dp: average across micro batches in current DP rank.
Default to "global".
:type global_normalize_scope: Literal["global", "dp"]
"""
raise NotImplementedError()

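Taken together, the docstring hunks above describe the updated train_batch contract: loss_fn maps (model output, input) to (loss, stats), loss_weight_fn gives each micro batch a weight for loss normalization, and token_normalize_scope picks whether that normalization spans all DP ranks or only the local one. A minimal caller sketch under assumed names; `engine`, `data`, and the default-constructed MicroBatchSpec are illustrative, not from the diff:

import torch
from typing import Dict, Tuple

def my_loss_fn(logits: torch.Tensor, input_) -> Tuple[torch.Tensor, Dict]:
    # Must return a scalar loss and a dict of statistics, per the docstring above.
    loss = logits.float().mean()
    return loss, {"loss": float(loss.detach())}

stats = engine.train_batch(
    input_=data,                                  # a SequenceSample batch (assumed available)
    mb_spec=MicroBatchSpec(),                     # default micro-batch splitting (assumed)
    loss_fn=my_loss_fn,
    loss_weight_fn=lambda mb: torch.tensor(1.0),  # equal weight per micro batch
    version_steps=0,
    token_normalize_scope="global",               # or "dp" for per-rank normalization
)
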
@@ -664,7 +664,6 @@ class PipeTrainInstrSetForMegatron(PipeTrainInstrSet):
loss: torch.Tensor = tensor_buffer.get(
"losses", micro_batch_id, remove=True
)
loss = self.engine.optim.scale_loss(loss)
loss.backward()
tensor_buffer.put("losses", micro_batch_id, loss.detach().clone())
return

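This hunk drops the scale_loss call from the backward instruction; the later hunks fold the optimizer's loss scale into loss_scale before the loss is stored, which is equivalent because gradients are linear in the loss. A small plain-PyTorch check of that equivalence, not repo code:

import torch

x = torch.ones(3, requires_grad=True)
scale = 128.0  # stands in for the optimizer's loss scale

# Old placement: the scale is applied at backward time.
(x.sum() * scale).backward()
grad_old = x.grad.clone()

# New placement: the scale is folded into loss_scale and multiplied in beforehand.
x.grad = None
loss = x.sum() * scale
loss.backward()

assert torch.equal(grad_old, x.grad)
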
@@ -777,7 +776,7 @@ class ReaLMegatronEngine(model_api.PipelinableEngine):
sum([loss_weight_fn(mb) for mb in mb_inputs]), dtype=torch.float32
)
if token_normalize_scope == "global":
total_loss_weight = dist.all_reduce(
dist.all_reduce(
total_loss_weight, group=constants.data_parallel_group()
)
if total_loss_weight == 0:

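The assignment form shown above is the bug being fixed: torch.distributed.all_reduce reduces the tensor in place and returns None unless async_op=True, so `total_loss_weight = dist.all_reduce(...)` would replace the tensor with None. The same correction appears in the MockTrainEngine and PipelineRunner hunks below. A single-process sanity check, not repo code:

import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

total_loss_weight = torch.tensor(8.0, dtype=torch.float32)
ret = dist.all_reduce(total_loss_weight)  # sums across ranks, in place

assert ret is None                        # nothing useful is returned in sync mode
assert total_loss_weight.item() == 8.0    # the tensor itself holds the reduced value

dist.destroy_process_group()
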
@@ -813,10 +812,15 @@ class ReaLMegatronEngine(model_api.PipelinableEngine):
loss, _stat = loss_fn(model_output, mb_input)
loss_scale = loss_weight_fn(mb_inputs[i]) / total_loss_weight
if token_normalize_scope == "global":
# Megatron will average gradients across DP ranks.
# If we normalize loss across micro batches of all DP ranks,
# we should revert the effect of gradient averaging in megatron
# to make sure loss from each token is scaled properly.
loss_scale *= constants.data_parallel_world_size()
loss_scale *= self.engine.optim.get_loss_scale()
loss *= loss_scale
with cuda_tmarked("bwd", CUDATimeMarkType.backward):
self.engine.optim.scale_loss(loss).backward()
loss.backward()
for k, v in _stat.items():
stat[k] += v

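The comment block explains the correction: under "global" normalization every micro batch is weighted by w_i / total_w, where total_w is summed over all DP ranks, but Megatron then averages the accumulated gradients over the DP world size. Multiplying loss_scale by the world size cancels that division. A small numeric sketch with assumed weights, not repo code:

dp_world_size = 4
rank_weights = [10.0, 30.0]   # loss weights of this rank's micro batches (assumed)
total_w = 160.0               # sum of weights across all DP ranks (assumed)

naive = [w / total_w for w in rank_weights]
corrected = [w / total_w * dp_world_size for w in rank_weights]

# After Megatron divides gradients by dp_world_size, only the corrected scales
# still weight each micro batch by w / total_w as intended.
print([s / dp_world_size for s in naive])      # [0.015625, 0.046875] -> under-weighted
print([s / dp_world_size for s in corrected])  # [0.0625, 0.1875] == w / total_w
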
@@ -132,9 +132,7 @@ class MockTrainEngine(model_api.PipelinableEngine):
sum([loss_weight_fn(mb) for mb in mb_inputs]), dtype=torch.float32
)
if token_normalize_scope == "global":
total_loss_weight = dist.all_reduce(
total_loss_weight, group=constants.data_parallel_group()
)
dist.all_reduce(total_loss_weight, group=constants.data_parallel_group())
if total_loss_weight == 0:
raise model_api.ZeroTotalLossWeightException(
"The sum of loss weights of all micro batches is zero."

@@ -754,6 +754,7 @@ class PipeTrainForwardCommInstrSet:

@dataclasses.dataclass
class PipeTrainInstrSet:
engine: Any

def _exec_optimizer_step(self, *args, **kwargs):
raise NotImplementedError()

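With `engine` declared on the base instruction-set dataclass, the pipeline runner can reach backend state through any instruction set uniformly, as `instr_set.engine.optim.get_loss_scale()` does in a later hunk. A hedged sketch of that pattern with a made-up subclass, not from the repo:

import dataclasses
from typing import Any

@dataclasses.dataclass
class PipeTrainInstrSet:
    engine: Any

    def _exec_optimizer_step(self, *args, **kwargs):
        raise NotImplementedError()

@dataclasses.dataclass
class ToyInstrSet(PipeTrainInstrSet):
    # Illustrative backend-specific subclass; the real one wraps a Megatron engine.
    def _exec_optimizer_step(self, *args, **kwargs):
        return self.engine.optim.step()
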
@@ -1013,9 +1014,7 @@ class PipelineRunner:
sum([loss_weight_fn(mb) for mb in mb_inputs]), dtype=torch.float32
)
if token_normalize_scope == "global":
total_loss_weight = dist.all_reduce(
total_loss_weight, group=constants.data_parallel_group()
)
dist.all_reduce(total_loss_weight, group=constants.data_parallel_group())

if constants.parallelism_rank() == 0:
logger.info(

@@ -1031,7 +1030,12 @@ class PipelineRunner:
tensor_buffer.put("n_pp_mbs", i, n_pp_mbs)
loss_scale = loss_weight_fn(mb_inputs[i]) / total_loss_weight
if token_normalize_scope == "global":
# Megatron will average gradients across DP ranks.
# If we normalize loss across micro batches of all DP ranks,
# we should revert the effect of gradient averaging in megatron
# to make sure loss from each token is scaled properly.
loss_scale *= constants.data_parallel_world_size()
loss_scale *= instr_set.engine.optim.get_loss_scale()
tensor_buffer.put("loss_scale", i, loss_scale)
tensor_buffer.put("version_steps", i, version_steps)
tensor_buffer.put("loss_fn", i, loss_fn)

@@ -89,6 +89,7 @@ def compute_packed_sft_loss(

class SFTInterface(model_api.ModelInterface):
token_normalize_scope: Literal["global", "dp"] = "global"

def train_step(
self, model: model_api.Model, data: SequenceSample, mb_spec: MicroBatchSpec

@@ -100,8 +101,10 @@ class SFTInterface(model_api.ModelInterface):
stat = module.train_batch(
input_=data,
loss_fn=compute_packed_sft_loss,
loss_weight_fn=lambda: 1,
token_normalize_scope="dp",
loss_weight_fn=lambda x: x.data["prompt_mask"]
.logical_not()
.count_nonzero(),
token_normalize_scope=self.token_normalize_scope,
mb_spec=mb_spec,
version_steps=model.version.global_step,
)

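The new loss_weight_fn weights each SFT micro batch by its count of non-prompt (answer) tokens instead of the earlier constant 1, and the normalization scope now follows the interface's token_normalize_scope attribute (default "global") rather than a hard-coded "dp". A minimal check of what the lambda computes, with an assumed prompt_mask layout:

import torch

# prompt_mask marks prompt positions as True and answer positions as False (assumed layout).
prompt_mask = torch.tensor([True, True, True, False, False, False, False])

weight = prompt_mask.logical_not().count_nonzero()
print(weight.item())  # 4 answer tokens -> this micro batch's loss weight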