mirror of https://github.com/inclusionAI/AReaL
0723_reformatted_5
This commit is contained in:
parent
82442b86fd
commit
00b5d878b3
|
@ -262,7 +262,6 @@ class BaseHFEngine(TrainEngine):
|
||||||
mb["max_seqlen"] = int(mb["max_seqlen"])
|
mb["max_seqlen"] = int(mb["max_seqlen"])
|
||||||
mb["use_cache"] = False
|
mb["use_cache"] = False
|
||||||
|
|
||||||
|
|
||||||
return mb_list
|
return mb_list
|
||||||
|
|
||||||
def train_batch(
|
def train_batch(
|
||||||
|
@ -285,12 +284,12 @@ class BaseHFEngine(TrainEngine):
|
||||||
)
|
)
|
||||||
assert total_loss_weight != 0
|
assert total_loss_weight != 0
|
||||||
dist.all_reduce(total_loss_weight)
|
dist.all_reduce(total_loss_weight)
|
||||||
|
|
||||||
# Process microbatches with gradient accumulation
|
# Process microbatches with gradient accumulation
|
||||||
for i, (pad_length, padded_mb_input, mb_input) in enumerate(
|
for i, (pad_length, padded_mb_input, mb_input) in enumerate(
|
||||||
zip(mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs)
|
zip(mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs)
|
||||||
):
|
):
|
||||||
|
|
||||||
outputs = self.model(**padded_mb_input)
|
outputs = self.model(**padded_mb_input)
|
||||||
|
|
||||||
logits = outputs.logits.squeeze(0)
|
logits = outputs.logits.squeeze(0)
|
||||||
|
|
|
@ -217,7 +217,7 @@ class FSDPEngine(BaseHFEngine):
|
||||||
|
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
mb_list = self.prepare_mb_list(input_)
|
mb_list = self.prepare_mb_list(input_)
|
||||||
|
|
||||||
total_loss_weight = torch.tensor(
|
total_loss_weight = torch.tensor(
|
||||||
sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32
|
sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32
|
||||||
)
|
)
|
||||||
|
|
|
@ -488,7 +488,11 @@ def unsqueeze_packed_tensor_dict(data: TensorDict) -> TensorDict:
|
||||||
new_data = {}
|
new_data = {}
|
||||||
for key, value in data.items():
|
for key, value in data.items():
|
||||||
if (
|
if (
|
||||||
key not in ["cu_seqlens", "max_seqlen",]
|
key
|
||||||
|
not in [
|
||||||
|
"cu_seqlens",
|
||||||
|
"max_seqlen",
|
||||||
|
]
|
||||||
and torch.is_tensor(value)
|
and torch.is_tensor(value)
|
||||||
and value.numel() == total_length
|
and value.numel() == total_length
|
||||||
):
|
):
|
||||||
|
|
Loading…
Reference in New Issue