mirror of https://github.com/inclusionAI/AReaL
fix: full_tensor() should happen in all rank (#187)
Co-authored-by: ChangyiYang <changyiyang2023@gmail.com>
parent fdd85e2e55
commit e395df1430
@@ -182,14 +182,16 @@ class FSDPEngine(BaseHFEngine):

     def _update_weights_from_distributed(self):
         """Broadcast parameters from rank 0 (FSDP2 compatible)."""
-        if dist.get_rank() == 0:
-            for name, param in self.model.named_parameters():
-                if isinstance(param.data, DTensor):
-                    tensor = param.data.full_tensor()
-                else:
-                    tensor = param.data
+        for name, param in self.model.named_parameters():
+            if isinstance(param.data, DTensor):
+                tensor = param.data.full_tensor()
+            else:
+                tensor = param.data
+            if dist.get_rank() == 0:
                 print(f"Broadcasting {name} with shape {tensor.shape}", flush=True)
                 dist.broadcast(tensor, src=0, group=self.weight_update_group)
             dist.barrier()
             del tensor  # optional, for memory hygiene
         torch.cuda.empty_cache()
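For context, DTensor.full_tensor() is a collective: it all-gathers the parameter shards across the FSDP device mesh, so every rank must enter the call or the gather never completes, which is why the rank-0 guard moves inside the loop. Below is a minimal, standalone sketch of the corrected pattern, assuming PyTorch >= 2.4 (torch.distributed.tensor.DTensor); the broadcast_full_params name and the model / weight_update_group arguments are stand-ins for the engine's own attributes and are not part of the commit.

    import torch
    import torch.distributed as dist
    from torch.distributed.tensor import DTensor


    def broadcast_full_params(model, weight_update_group):
        """Gather FSDP-sharded weights on every rank, then broadcast from rank 0."""
        for name, param in model.named_parameters():
            # full_tensor() all-gathers the shards across the FSDP mesh; every
            # rank must call it, so it sits outside the rank-0 check.
            if isinstance(param.data, DTensor):
                tensor = param.data.full_tensor()
            else:
                tensor = param.data
            # Only rank 0 sends into the weight-update process group; the
            # receiving workers in that group issue the matching broadcast.
            if dist.get_rank() == 0:
                print(f"Broadcasting {name} with shape {tensor.shape}", flush=True)
                dist.broadcast(tensor, src=0, group=weight_update_group)
            dist.barrier()
            del tensor  # drop the gathered copy before materializing the next one
        torch.cuda.empty_cache()

Deleting the gathered tensor inside the loop keeps peak memory to one fully materialized parameter at a time, matching the "memory hygiene" note in the diff.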