mirror of https://github.com/inclusionAI/AReaL
0723_reformatted_5
This commit is contained in:
parent
82442b86fd
commit
00b5d878b3
|
@ -262,7 +262,6 @@ class BaseHFEngine(TrainEngine):
|
|||
mb["max_seqlen"] = int(mb["max_seqlen"])
|
||||
mb["use_cache"] = False
|
||||
|
||||
|
||||
return mb_list
|
||||
|
||||
def train_batch(
|
||||
|
|
|
@ -488,7 +488,11 @@ def unsqueeze_packed_tensor_dict(data: TensorDict) -> TensorDict:
|
|||
new_data = {}
|
||||
for key, value in data.items():
|
||||
if (
|
||||
key not in ["cu_seqlens", "max_seqlen",]
|
||||
key
|
||||
not in [
|
||||
"cu_seqlens",
|
||||
"max_seqlen",
|
||||
]
|
||||
and torch.is_tensor(value)
|
||||
and value.numel() == total_length
|
||||
):
|
||||
|
|
Loading…
Reference in New Issue