mirror of https://github.com/inclusionAI/AReaL

commit 00b5d878b3 (parent 82442b86fd): 0723_reformatted_5
@@ -262,7 +262,6 @@ class BaseHFEngine(TrainEngine):
        mb["max_seqlen"] = int(mb["max_seqlen"])
        mb["use_cache"] = False

        return mb_list

    def train_batch(
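The first hunk touches microbatch preparation. As a hedged sketch of why these two assignments matter (the helper name finalize_mb and the sample dict are hypothetical, not from the repo): varlen attention kernels typically want max_seqlen as a plain Python int rather than a 0-d tensor, and use_cache=False keeps a training forward pass from allocating a KV cache.

import torch

def finalize_mb(mb: dict) -> dict:
    # Mirrors the hunk above: varlen attention kernels expect a plain
    # Python int for `max_seqlen`, and a training forward pass should
    # not allocate a KV cache.
    mb["max_seqlen"] = int(mb["max_seqlen"])
    mb["use_cache"] = False
    return mb

mb = {
    "input_ids": torch.randint(0, 100, (1, 16)),
    "cu_seqlens": torch.tensor([0, 5, 16], dtype=torch.int32),
    "max_seqlen": torch.tensor(11),  # 0-d tensor before the cast
}
mb = finalize_mb(mb)
assert isinstance(mb["max_seqlen"], int) and mb["use_cache"] is False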
@@ -285,12 +284,12 @@ class BaseHFEngine(TrainEngine):
        )
        assert total_loss_weight != 0
        dist.all_reduce(total_loss_weight)

        # Process microbatches with gradient accumulation
        for i, (pad_length, padded_mb_input, mb_input) in enumerate(
            zip(mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs)
        ):
            outputs = self.model(**padded_mb_input)
            logits = outputs.logits.squeeze(0)
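The second hunk is the gradient-accumulation loop. Below is a minimal single-process sketch of the pattern it implements; the toy model, data, and loss_weight_fn are stand-ins, and the real code first all-reduces total_loss_weight across data-parallel ranks as shown in the hunk.

import torch

model = torch.nn.Linear(4, 1)
mbs = [torch.randn(3, 4), torch.randn(5, 4)]
loss_weight_fn = lambda mb: float(mb.shape[0])  # e.g. token count

total_loss_weight = torch.tensor(
    sum(loss_weight_fn(mb) for mb in mbs), dtype=torch.float32
)
assert total_loss_weight != 0

model.zero_grad()
for mb in mbs:
    loss = model(mb).pow(2).mean()
    # Scale each microbatch by its share of the total weight so the
    # accumulated gradients match one pass over the whole batch.
    (loss * loss_weight_fn(mb) / total_loss_weight).backward()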
@@ -217,7 +217,7 @@ class FSDPEngine(BaseHFEngine):
        self.optimizer.zero_grad()
        mb_list = self.prepare_mb_list(input_)

        total_loss_weight = torch.tensor(
            sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32
        )
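The FSDP hunk computes a local sum of loss weights as a float32 tensor, which the previous hunk then all-reduces before normalization. A runnable single-process illustration of that reduction, assuming the gloo backend and a world size of 1 (the all_reduce is a no-op here but sums across ranks in a real data-parallel job):

import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

local_weight = torch.tensor(42.0, dtype=torch.float32)
dist.all_reduce(local_weight)  # sums each rank's loss weight
print(local_weight.item())
dist.destroy_process_group()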
@@ -488,7 +488,11 @@ def unsqueeze_packed_tensor_dict(data: TensorDict) -> TensorDict:
     new_data = {}
     for key, value in data.items():
         if (
-            key not in ["cu_seqlens", "max_seqlen",]
+            key
+            not in [
+                "cu_seqlens",
+                "max_seqlen",
+            ]
             and torch.is_tensor(value)
             and value.numel() == total_length
         ):
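The last hunk only re-wraps the membership test, but the filter itself is worth spelling out. A hedged sketch of what a function like this does (a plain dict stands in for TensorDict, and the body beyond the shown condition is assumed): tensors holding one value per packed token gain a leading batch dim of 1, while sequence metadata passes through unchanged.

import torch

def unsqueeze_packed(data: dict, total_length: int) -> dict:
    new_data = {}
    for key, value in data.items():
        if (
            key not in ["cu_seqlens", "max_seqlen"]
            and torch.is_tensor(value)
            and value.numel() == total_length
        ):
            value = value.unsqueeze(0)  # packed 1-D -> (1, total_length)
        new_data[key] = value
    return new_data

packed = {
    "input_ids": torch.arange(16),
    "cu_seqlens": torch.tensor([0, 5, 16], dtype=torch.int32),
    "max_seqlen": 11,
}
out = unsqueeze_packed(packed, total_length=16)
assert out["input_ids"].shape == (1, 16)
assert out["cu_seqlens"].dim() == 1  # metadata left untouched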