mirror of https://github.com/inclusionAI/AReaL
fix: full_tensor() should happen in all rank (#187)
Co-authored-by: ChangyiYang <changyiyang2023@gmail.com>
parent fdd85e2e55
commit e395df1430
```diff
@@ -182,16 +182,18 @@ class FSDPEngine(BaseHFEngine):
 
     def _update_weights_from_distributed(self):
         """Broadcast parameters from rank 0 (FSDP2 compatible)."""
-        if dist.get_rank() == 0:
-            for name, param in self.model.named_parameters():
-                if isinstance(param.data, DTensor):
-                    tensor = param.data.full_tensor()
-                else:
-                    tensor = param.data
-                print(f"Broadcasting {name} with shape {tensor.shape}", flush=True)
-                dist.broadcast(tensor, src=0, group=self.weight_update_group)
-                del tensor  # optional, for memory hygiene
-            torch.cuda.empty_cache()
+        for name, param in self.model.named_parameters():
+            if isinstance(param.data, DTensor):
+                tensor = param.data.full_tensor()
+            else:
+                tensor = param.data
+            if dist.get_rank() == 0:
+                print(f"Broadcasting {name} with shape {tensor.shape}", flush=True)
+                dist.broadcast(tensor, src=0, group=self.weight_update_group)
+            dist.barrier()
+            del tensor  # optional, for memory hygiene
+        torch.cuda.empty_cache()
 
     def get_param_meta_for_distributed_update(self) -> Dict[str, Tuple[int]]:
         """Return a dict mapping param name to its shape (expanded if DTensor)."""
```
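Why the guard had to move: `DTensor.full_tensor()` is a collective (an all-gather over the tensor's device mesh), so every rank that holds a shard must call it. With the old `if dist.get_rank() == 0:` wrapped around the whole loop, rank 0 entered the all-gather alone and the job hung. The standalone sketch below only illustrates that pattern; it is not AReaL code, it assumes a recent PyTorch that exposes the public `torch.distributed.tensor` API, and it is meant to be launched on a single node with `torchrun --nproc_per_node=2`.

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor


def main():
    dist.init_process_group("nccl")
    # Single-node assumption: global rank doubles as the local CUDA device index.
    torch.cuda.set_device(dist.get_rank())
    mesh = init_device_mesh("cuda", (dist.get_world_size(),))

    # A parameter sharded along dim 0 across all ranks, the way FSDP2 shards weights.
    full_weight = torch.arange(32, dtype=torch.float32, device="cuda").reshape(8, 4)
    sharded = distribute_tensor(full_weight, mesh, placements=[Shard(0)])

    # Correct (post-fix pattern): every rank joins the all-gather.
    gathered = sharded.full_tensor()

    # Broken (pre-fix pattern): only rank 0 would enter the collective and then
    # block forever, waiting for peers that never call it.
    # if dist.get_rank() == 0:
    #     gathered = sharded.full_tensor()

    if dist.get_rank() == 0:
        print("gathered shape:", tuple(gathered.shape), flush=True)
    dist.barrier()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

The `dist.barrier()` added inside the loop by the commit presumably serves the complementary purpose: it keeps non-zero ranks from racing ahead and materializing the next full tensor while rank 0 is still broadcasting the current one.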
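The other symbol visible in the hunk's trailing context, `get_param_meta_for_distributed_update`, advertises a name-to-shape map that is "expanded if DTensor". A minimal sketch of that idea follows; it borrows the method's docstring but is not AReaL's implementation (it is written as a free function over a model so it is self-contained, under the same PyTorch assumption as above). The key point is that `DTensor.shape` already reports the logical, unsharded shape, so building the map needs no communication, unlike `full_tensor()`.

```python
from typing import Dict, Tuple

import torch
from torch.distributed.tensor import DTensor


def get_param_meta_for_distributed_update(model: torch.nn.Module) -> Dict[str, Tuple[int, ...]]:
    """Return a dict mapping param name to its shape (expanded if DTensor)."""
    meta: Dict[str, Tuple[int, ...]] = {}
    for name, param in model.named_parameters():
        # A DTensor's .shape is the logical (global) shape, so the "expanded"
        # shape is available locally; no collective call is required here.
        shape = param.data.shape if isinstance(param.data, DTensor) else param.shape
        meta[name] = tuple(shape)
    return meta
```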