mirror of https://github.com/inclusionAI/AReaL
memory saving for log_softmax and gather
This commit is contained in:
parent
234e3dd3a0
commit
d70586f85e
|
@ -166,6 +166,27 @@ def build_leave_one_indices(
|
|||
)
|
||||
|
||||
|
||||
@torch.compile
def gather_logprobs(
    logits: torch.Tensor,
    labels: torch.Tensor,
) -> torch.Tensor:
    """Return the log probability that ``logits`` assigns to each label.

    A log-softmax is taken over the last dimension of ``logits`` and, for
    every position, the entry indexed by the matching element of ``labels``
    is selected. No shifting and no per-sequence trimming happens here: the
    output has one value per label. Callers that need shifted log probs for
    packed sequences must align ``labels`` beforehand and drop the unused
    positions afterwards.

    The function is wrapped in ``torch.compile``; the commit that introduced
    it describes the goal as saving memory for log_softmax + gather
    (presumably by letting the compiler fuse the two ops -- confirm with a
    profiler before relying on it).

    Args:
        logits (torch.FloatTensor): Unnormalized scores with shape
            [..., vocab_size]; for packed inputs, [tot_seqlen, vocab_size].
        labels (torch.LongTensor): Indices into the last dimension of
            ``logits`` with shape ``logits.shape[:-1]``; for packed inputs,
            [tot_seqlen].

    Returns:
        torch.FloatTensor: Log probabilities of the labels, with the same
        shape as ``labels``.
    """
    # Normalize over the vocabulary (last) dimension.
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
    # gather() needs an index with the same rank as its input, hence the
    # temporary trailing dimension that is squeezed away again afterwards.
    log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
    return log_probs_labels
|
||||
|
||||
|
||||
def gather_packed_shifted_log_probs(
|
||||
logits: torch.FloatTensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
|
@ -174,11 +195,11 @@ def gather_packed_shifted_log_probs(
|
|||
"""Gather log probs from packed input_ids and logits.
|
||||
|
||||
Args:
|
||||
logits_ (torch.FloatTensor): Shape [tot_seqlen]. The final value at the end of
|
||||
logits (torch.FloatTensor): Shape [tot_seqlen]. The final value at the end of
|
||||
each sequence is not used.
|
||||
cu_seqlens (torch.Tensor): Shape [#seqs + 1]. Indices marking the start
|
||||
and end of each sequences.
|
||||
labels_ (torch.LongTensor): Labels or input_ids with shape [tot_seqlen].
|
||||
and end of each sequence.
|
||||
labels (torch.LongTensor): Labels or input_ids with shape [tot_seqlen].
|
||||
The first value at the beginning of each sequence has no corresponding log prob.
|
||||
|
||||
Returns:
|
||||
|
@ -202,8 +223,7 @@ def gather_packed_shifted_log_probs(
|
|||
# for i in range(cu_seqlens.shape[0] - 1)
|
||||
# ])
|
||||
# shift labels one step to the left and pad it to match the shape of logits
|
||||
log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
|
||||
log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
|
||||
log_probs_labels = gather_logprobs(logits, labels)
|
||||
log_probs_labels = log_probs_labels[leave_one_indices]
|
||||
assert log_probs_labels.shape[0] == logits_shape[0] - cu_seqlens.shape[0] + 1, (
|
||||
log_probs_labels.shape,
|
||||
|
|
Loading…
Reference in New Issue