fix: full_tensor() should happen on all ranks (#187)

Co-authored-by: ChangyiYang <changyiyang2023@gmail.com>
Authored by ChangyiYang on 2025-07-20 19:57:38 -07:00; committed by GitHub
parent fdd85e2e55
commit e395df1430
1 changed file with 10 additions and 8 deletions


@@ -182,14 +182,16 @@ class FSDPEngine(BaseHFEngine):
    def _update_weights_from_distributed(self):
        """Broadcast parameters from rank 0 (FSDP2 compatible)."""
        for name, param in self.model.named_parameters():
            if isinstance(param.data, DTensor):
                # full_tensor() is a collective gather over the FSDP device mesh,
                # so every rank must call it, not just rank 0.
                tensor = param.data.full_tensor()
            else:
                tensor = param.data
            if dist.get_rank() == 0:
                print(f"Broadcasting {name} with shape {tensor.shape}", flush=True)
                dist.broadcast(tensor, src=0, group=self.weight_update_group)
            dist.barrier()
            del tensor  # optional, for memory hygiene
        torch.cuda.empty_cache()
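
Background on the fix: DTensor.full_tensor() is a collective all-gather over the tensor's device mesh, so every rank holding a shard has to enter the call. When it sat behind the old "if dist.get_rank() == 0" guard, the non-zero ranks never joined the collective and rank 0 blocked indefinitely. The standalone sketch below is not part of this repository; it assumes PyTorch 2.5+ (where DTensor lives under torch.distributed.tensor) and a made-up script name, and it only illustrates the pattern the patch enforces: all ranks gather, then only rank 0 consumes the result.

# Minimal sketch; launch with: torchrun --nproc_per_node=2 full_tensor_demo.py
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard

def main():
    dist.init_process_group(backend="gloo")
    mesh = init_device_mesh("cpu", (dist.get_world_size(),))

    # Shard a small "parameter" across all ranks, one shard per rank.
    full = torch.arange(8, dtype=torch.float32)
    sharded = distribute_tensor(full, mesh, placements=[Shard(0)])

    # Every rank must reach this line: full_tensor() all-gathers the shards.
    # If only rank 0 called it, rank 0 would wait on the other ranks forever.
    gathered = sharded.full_tensor()

    if dist.get_rank() == 0:
        # Only rank 0 goes on to use (e.g. broadcast) the gathered weights.
        print("gathered:", gathered)

    dist.destroy_process_group()

if __name__ == "__main__":
    main()

The patched method above follows the same shape: the full tensor is materialized on every training rank first, and only then does rank 0 broadcast it over weight_update_group.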