diff --git a/arealite/engine/fsdp_engine.py b/arealite/engine/fsdp_engine.py index 65fb1f7..b07e284 100644 --- a/arealite/engine/fsdp_engine.py +++ b/arealite/engine/fsdp_engine.py @@ -182,16 +182,18 @@ class FSDPEngine(BaseHFEngine): def _update_weights_from_distributed(self): """Broadcast parameters from rank 0 (FSDP2 compatible).""" - if dist.get_rank() == 0: - for name, param in self.model.named_parameters(): - if isinstance(param.data, DTensor): - tensor = param.data.full_tensor() - else: - tensor = param.data + + for name, param in self.model.named_parameters(): + if isinstance(param.data, DTensor): + tensor = param.data.full_tensor() + else: + tensor = param.data + if dist.get_rank() == 0: print(f"Broadcasting {name} with shape {tensor.shape}", flush=True) dist.broadcast(tensor, src=0, group=self.weight_update_group) - del tensor # optional, for memory hygiene - torch.cuda.empty_cache() + dist.barrier() + del tensor # optional, for memory hygiene + torch.cuda.empty_cache() def get_param_meta_for_distributed_update(self) -> Dict[str, Tuple[int]]: """Return a dict mapping param name to its shape (expanded if DTensor)."""