0723_reformatted_5

This commit is contained in:
朱晗 2025-07-23 17:13:49 +08:00
parent 82442b86fd
commit 00b5d878b3
3 changed files with 9 additions and 6 deletions

View File

@ -262,7 +262,6 @@ class BaseHFEngine(TrainEngine):
mb["max_seqlen"] = int(mb["max_seqlen"])
mb["use_cache"] = False
return mb_list
def train_batch(

View File

@ -488,7 +488,11 @@ def unsqueeze_packed_tensor_dict(data: TensorDict) -> TensorDict:
new_data = {}
for key, value in data.items():
if (
key not in ["cu_seqlens", "max_seqlen",]
key
not in [
"cu_seqlens",
"max_seqlen",
]
and torch.is_tensor(value)
and value.numel() == total_length
):