memory saving for log_softmax and gather

This commit is contained in:
穰侯 2025-03-14 10:18:28 +08:00
parent 234e3dd3a0
commit d70586f85e
1 changed files with 25 additions and 5 deletions

View File

@ -166,6 +166,27 @@ def build_leave_one_indices(
)
@torch.compile
def gather_logprobs(
logits: torch.Tensor,
labels: torch.Tensor,
):
"""Gather log probs from logits and labels.
Args:
logits (torch.FloatTensor): Shape [tot_seqlen]. The final value at the end of
each sequence is not used.
labels (torch.LongTensor): Labels or input_ids with shape [tot_seqlen].
The first value at the beginning of each sequence has no corresponding log prob.
Returns:
torch.FloatTensor: Log probability with shape [tot_seqlen - #seqs].
"""
log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
return log_probs_labels
def gather_packed_shifted_log_probs(
logits: torch.FloatTensor,
cu_seqlens: torch.Tensor,
@ -174,11 +195,11 @@ def gather_packed_shifted_log_probs(
"""Gather log probs from packed input_ids and logits.
Args:
logits_ (torch.FloatTensor): Shape [tot_seqlen]. The final value at the end of
logits (torch.FloatTensor): Shape [tot_seqlen]. The final value at the end of
each sequence is not used.
cu_seqlens (torch.Tensor): Shape [#seqs + 1]. Indices marking the start
and end of each sequences.
labels_ (torch.LongTensor): Labels or input_ids with shape [tot_seqlen].
and end of each sequence.
labels (torch.LongTensor): Labels or input_ids with shape [tot_seqlen].
The first value at the beginning of each sequence has no corresponding log prob.
Returns:
@ -202,8 +223,7 @@ def gather_packed_shifted_log_probs(
# for i in range(cu_seqlens.shape[0] - 1)
# ])
# shift labels one step to the left and pad it to match the shape of logits
log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
log_probs_labels = gather_logprobs(logits, labels)
log_probs_labels = log_probs_labels[leave_one_indices]
assert log_probs_labels.shape[0] == logits_shape[0] - cu_seqlens.shape[0] + 1, (
log_probs_labels.shape,