0723_reformatted_5

朱晗 2025-07-23 17:13:49 +08:00
parent 82442b86fd
commit 00b5d878b3
3 changed files with 9 additions and 6 deletions


@@ -262,7 +262,6 @@ class BaseHFEngine(TrainEngine):
            mb["max_seqlen"] = int(mb["max_seqlen"])
            mb["use_cache"] = False
        return mb_list
    def train_batch(
@@ -285,12 +284,12 @@ class BaseHFEngine(TrainEngine):
        )
        assert total_loss_weight != 0
        dist.all_reduce(total_loss_weight)
        # Process microbatches with gradient accumulation
        for i, (pad_length, padded_mb_input, mb_input) in enumerate(
            zip(mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs)
        ):
            outputs = self.model(**padded_mb_input)
            logits = outputs.logits.squeeze(0)
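
The hunk above is the gradient-accumulation loop: per-microbatch loss weights are summed, all-reduced across data-parallel ranks, and used to normalize each microbatch's loss so the accumulated gradient matches a single pass over the global batch. A minimal sketch of that pattern, assuming a generic model, loss_fn, and loss_weight_fn (names not visible in the hunk are placeholders, not the actual BaseHFEngine API):

import torch
import torch.distributed as dist


def train_step(model, optimizer, microbatches, loss_fn, loss_weight_fn):
    """Gradient accumulation over microbatches with a globally normalized loss."""
    optimizer.zero_grad()

    # Sum of per-microbatch weights (e.g. token counts); all-reduce so every
    # data-parallel rank divides by the same denominator.
    total_loss_weight = torch.tensor(
        sum(loss_weight_fn(mb) for mb in microbatches), dtype=torch.float32
    )
    assert total_loss_weight != 0
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(total_loss_weight)

    for mb in microbatches:
        outputs = model(**mb)
        # Scale each microbatch loss so the accumulated gradients add up to the
        # gradient of the weighted average over the global batch.
        loss = loss_fn(outputs, mb) * loss_weight_fn(mb) / total_loss_weight
        loss.backward()

    optimizer.step()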


@@ -217,7 +217,7 @@ class FSDPEngine(BaseHFEngine):
        self.optimizer.zero_grad()
        mb_list = self.prepare_mb_list(input_)
        total_loss_weight = torch.tensor(
            sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32
        )


@@ -488,7 +488,11 @@ def unsqueeze_packed_tensor_dict(data: TensorDict) -> TensorDict:
     new_data = {}
     for key, value in data.items():
         if (
-            key not in ["cu_seqlens", "max_seqlen",]
+            key
+            not in [
+                "cu_seqlens",
+                "max_seqlen",
+            ]
             and torch.is_tensor(value)
             and value.numel() == total_length
         ):
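
For reference, this last change is purely a line-wrapping reformat: unsqueeze_packed_tensor_dict still skips the sequence metadata keys and only selects tensors whose element count equals total_length. A minimal standalone sketch of that selection logic (the unsqueeze(0) body and the plain-dict interface are assumptions for illustration, not the actual implementation):

import torch


def unsqueeze_packed(data: dict, total_length: int) -> dict:
    """Add a leading batch dimension to packed 1-D tensors.

    "cu_seqlens" and "max_seqlen" are sequence metadata and are copied through
    unchanged; only tensors whose element count equals total_length are unsqueezed.
    """
    new_data = {}
    for key, value in data.items():
        if (
            key
            not in [
                "cu_seqlens",
                "max_seqlen",
            ]
            and torch.is_tensor(value)
            and value.numel() == total_length
        ):
            new_data[key] = value.unsqueeze(0)
        else:
            new_data[key] = value
    return new_data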