0723_reformatted_5

朱晗 2025-07-23 17:13:49 +08:00
parent 82442b86fd
commit 00b5d878b3
3 changed files with 9 additions and 6 deletions


@@ -262,7 +262,6 @@ class BaseHFEngine(TrainEngine):
             mb["max_seqlen"] = int(mb["max_seqlen"])
             mb["use_cache"] = False
         return mb_list
-
     def train_batch(
@@ -285,12 +284,12 @@ class BaseHFEngine(TrainEngine):
         )
         assert total_loss_weight != 0
         dist.all_reduce(total_loss_weight)

         # Process microbatches with gradient accumulation
         for i, (pad_length, padded_mb_input, mb_input) in enumerate(
             zip(mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs)
-            ):
+        ):
             outputs = self.model(**padded_mb_input)
             logits = outputs.logits.squeeze(0)

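The loop touched by this hunk performs gradient accumulation over packed microbatches: each microbatch's loss is weighted and divided by the all-reduced total_loss_weight, so gradients summed across microbatches and data-parallel ranks match a single large-batch step. A minimal sketch of that pattern, with hypothetical loss_fn and loss_weight_fn callables standing in for arguments this diff does not show:

import torch

def accumulate_microbatch_grads(model, mb_list, total_loss_weight, loss_fn, loss_weight_fn):
    # Sketch only: mirrors the loop above. backward() runs once per
    # microbatch without zeroing grads, so gradients sum across iterations.
    for padded_mb_input, mb_input in zip(mb_list.padded_mbs, mb_list.mbs):
        outputs = model(**padded_mb_input)
        logits = outputs.logits.squeeze(0)  # drop the leading packed-batch dim
        loss = loss_fn(logits, mb_input)    # hypothetical per-microbatch loss
        # Scale so the accumulated gradient equals that of the
        # weight-averaged loss over the whole global batch.
        (loss * loss_weight_fn(mb_input) / total_loss_weight).backward()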

@@ -217,7 +217,7 @@ class FSDPEngine(BaseHFEngine):
         self.optimizer.zero_grad()
         mb_list = self.prepare_mb_list(input_)
         total_loss_weight = torch.tensor(
             sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32
         )
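The total_loss_weight built here is the per-rank sum that the base class then all-reduces (the dist.all_reduce in the previous file), so every data-parallel rank normalizes by the same global value. A small sketch of the pattern, assuming an initialized default process group:

import torch
import torch.distributed as dist

def global_loss_weight(mbs, loss_weight_fn):
    # Per-rank sum of microbatch weights (e.g. valid-token counts).
    total = torch.tensor(sum(loss_weight_fn(mb) for mb in mbs), dtype=torch.float32)
    assert total != 0
    dist.all_reduce(total)  # defaults to SUM, giving the global normalizer
    return total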


@@ -488,7 +488,11 @@ def unsqueeze_packed_tensor_dict(data: TensorDict) -> TensorDict:
     new_data = {}
     for key, value in data.items():
         if (
-            key not in ["cu_seqlens", "max_seqlen",]
+            key
+            not in [
+                "cu_seqlens",
+                "max_seqlen",
+            ]
             and torch.is_tensor(value)
             and value.numel() == total_length
         ):
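The condition reformatted here decides which entries of a packed TensorDict get a batch dimension: packing metadata (cu_seqlens, max_seqlen) is left untouched, while every tensor with exactly total_length elements is unsqueezed. An illustrative standalone version, assuming the tensordict package's TensorDict (as the type hint suggests) and that total_length is the packed token count taken from input_ids, which the hunk does not show:

import torch
from tensordict import TensorDict

def unsqueeze_packed(data: TensorDict) -> TensorDict:
    # Assumption: "input_ids" holds the packed 1-D token stream.
    total_length = data["input_ids"].numel()
    new_data = {}
    for key, value in data.items():
        if (
            key not in ["cu_seqlens", "max_seqlen"]
            and torch.is_tensor(value)
            and value.numel() == total_length
        ):
            value = value.unsqueeze(0)  # [total_length] -> [1, total_length]
        new_data[key] = value
    return TensorDict(new_data, batch_size=[1])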