PullRequest: 46 Call all runtime barriers on CPU process groups and fix SGLang performance with TP > 1

Merge branch fw/fix-gpu-barrier of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/46?tab=diff

Signed-off-by: 闻通 <albert.zty@antgroup.com>


博惟 2025-03-19 16:29:31 +08:00
parent 9a9d86112e
commit 9c55827b32
6 changed files with 45 additions and 17 deletions
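
Every change in this PR follows the same pattern: wherever a barrier is used purely for synchronization, it now runs on a Gloo (CPU) process group that mirrors the existing NCCL (GPU) group over the same ranks. With TP > 1, the non-zero tensor-parallel ranks of the SGLang engine park in a barrier while rank 0 talks to the server process that owns the GPU; an NCCL barrier launches a collective on that GPU and competes with the server, which is presumably what hurt generation throughput, whereas a Gloo barrier waits on the host. A minimal sketch of the pattern, using plain `torch.distributed` instead of the repository's `topology.new_or_get_group` helper (the function names below are illustrative, not part of AReaL's API):

```python
import torch.distributed as dist


def init_groups(ranks):
    # Assumes dist.init_process_group(...) has already been called on every rank.
    gpu_group = dist.new_group(ranks, backend="nccl")  # GPU tensor traffic
    cpu_group = dist.new_group(ranks, backend="gloo")  # synchronization only
    return gpu_group, cpu_group


def wait_for_peers(cpu_group):
    # Blocks on the host; no CUDA kernel or NCCL call is issued, so a
    # co-located SGLang server process keeps the GPU to itself.
    dist.barrier(group=cpu_group)
```

In the diff below, `topology.new_or_get_group(ranks, backend="gloo")` plays the role of `dist.new_group`, and the resulting group is registered in `constants` so that later barriers can look it up.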

View File

@ -179,6 +179,9 @@ _grids: Dict["ModelName", "ParallelGrid"] = {}
 _pgroups: Dict["ModelName", Any] = (
     {}
 )  # torch.distributed.ProcessGroup, not type hint here to avoid importing torch
+_cpu_pgroups: Dict["ModelName", Any] = (
+    {}
+)  # torch.distributed.ProcessGroup, not type hint here to avoid importing torch
 _pgroup_ranks: Dict["ModelName", List[int]] = {}
 _self_group = None
 _rank_mapping: Dict["ModelName", Dict["ModelShardID", int]] = {}
@ -261,6 +264,13 @@ def set_parallelism_group(model_name: "ModelName", pgroup, ranks):
     _pgroup_ranks[model_name] = ranks
+def set_cpu_parallelism_group(model_name: "ModelName", pgroup):
+    global _cpu_pgroups
+    if model_name in _cpu_pgroups:
+        raise RuntimeError(f"Parallelism group for model {model_name} is already set.")
+    _cpu_pgroups[model_name] = pgroup
 def set_self_group(pgroup):
     global _self_group
     if _self_group is not None:
@ -384,6 +394,15 @@ def parallelism_group():
     return _pgroups[_model_name]
+def cpu_parallelism_group():
+    """Returns the GLOO 3D parallelism group of a specific model."""
+    if _model_name is None:
+        raise RuntimeError("Global constant `model_name` is accessed before set.")
+    if _cpu_pgroups.get(_model_name, None) is None:
+        raise RuntimeError(f"Parallelism group for model {_model_name} is not set.")
+    return _cpu_pgroups[_model_name]
 def parallelism_group_ranks():
     if _model_name is None:
         raise RuntimeError("Global constant `model_name` is accessed before set.")
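
For reference, the new accessor pair is used in two phases: the Gloo group is registered once per model during communication setup and fetched later from inside that model's scope. A hedged sketch; `constants.model_scope` is taken from the worker code further down, while the `realhf.base.constants` import path is an assumption:

```python
import torch.distributed as dist

from realhf.base import constants  # import path assumed, not shown in this diff


def register_cpu_group(model_name, gloo_group):
    # Called once during communication setup (see the setup_global_comm hunk below).
    constants.set_cpu_parallelism_group(model_name, gloo_group)


def barrier_model_workers(model_name):
    # Called at runtime; waits on the host only, never touching the GPU.
    with constants.model_scope(model_name):
        dist.barrier(group=constants.cpu_parallelism_group())
```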

View File

@ -182,7 +182,7 @@ class SGLangGenerationEngine(PipelinableEngine):
         request_timeout: int = 1800,
     ):
         if constants.model_parallel_rank() != 0:
-            dist.barrier(group=constants.model_parallel_group())
+            dist.barrier(group=constants.model_parallel_cpu_group())
             return
         # Start the serving process
         self.server_proc = mp.Process(
@ -209,7 +209,7 @@ class SGLangGenerationEngine(PipelinableEngine):
         # offload weights/cache
         self.hybrid_train = hybrid_train
-        dist.barrier(group=constants.model_parallel_group())
+        dist.barrier(group=constants.model_parallel_cpu_group())
     def __del__(self):
         if hasattr(self, "server_proc"):
@ -352,7 +352,7 @@ class SGLangGenerationEngine(PipelinableEngine):
"because we force to skip_tokenizer_init."
             )
         if constants.model_parallel_rank() != 0:
-            dist.barrier(group=constants.model_parallel_group())
+            dist.barrier(group=constants.model_parallel_cpu_group())
             return None, None, None
         results = asyncio.run(
@ -363,12 +363,12 @@ class SGLangGenerationEngine(PipelinableEngine):
                 gconfig=gconfig,
             )
         )
-        dist.barrier(group=constants.model_parallel_group())
+        dist.barrier(group=constants.model_parallel_cpu_group())
         return results
     def update_weights_from_disk(self, path):
         if constants.model_parallel_rank() != 0:
-            dist.barrier(group=constants.model_parallel_group())
+            dist.barrier(group=constants.model_parallel_cpu_group())
             return
         async def _fn():
@ -379,7 +379,7 @@ class SGLangGenerationEngine(PipelinableEngine):
             await client.async_update_weights_from_disk(path)
         asyncio.run(_fn())
-        dist.barrier(group=constants.model_parallel_group())
+        dist.barrier(group=constants.model_parallel_cpu_group())
 @dataclasses.dataclass
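
The SGLang engine changes above all instantiate one control-flow pattern: tensor-parallel rank 0 drives the server and every other rank parks at a barrier, so that barrier must be a CPU one or the parked ranks keep the GPU busy alongside the server. A stripped-down sketch of the pattern (the function and argument names are illustrative, not the engine's real interface):

```python
import torch.distributed as dist


def generate_on_rank0(tp_rank, cpu_group, run_generation):
    # Only TP rank 0 talks to the SGLang server; the rest wait on the host.
    if tp_rank != 0:
        dist.barrier(group=cpu_group)
        return None
    results = run_generation()     # e.g. asyncio.run(...) against the server
    dist.barrier(group=cpu_group)  # release the waiting ranks
    return results
```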

View File

@ -136,6 +136,8 @@ def setup_global_comm(
     for model_name, ranks in mw_ranks.items():
         model_groups[model_name] = topology.new_or_get_group(ranks, backend=backend)
         constants.set_parallelism_group(model_name, model_groups[model_name], ranks)
+        cpu_group = topology.new_or_get_group(ranks, backend="gloo")
+        constants.set_cpu_parallelism_group(model_name, cpu_group)
     self_group = None
     for i in range(world_size):

View File

@ -62,6 +62,9 @@ class ParamReallocInfo:
     param_realloc_model_group: Dict[
         ParamReallocModelPair, torch.distributed.ProcessGroup
     ]
+    param_realloc_model_cpu_group: Dict[
+        ParamReallocModelPair, torch.distributed.ProcessGroup
+    ]
     param_realloc_groups: Dict[ParamReallocPair, torch.distributed.ProcessGroup]
     param_realloc_src_ranks: Dict[ParamReallocPair, int]
     param_realloc_dst_ranks: Dict[ParamReallocPair, List[int]]
@ -270,6 +273,7 @@ def setup_param_realloc(
     param_realloc_src_ranks = {}
     param_realloc_dst_ranks = {}
     param_realloc_model_group = {}
+    param_realloc_model_cpu_group = {}
     if param_realloc_pairs is not None:
         for src, dst in param_realloc_pairs:
             _create_param_realloc_groups(
@ -296,11 +300,15 @@ def setup_param_realloc(
             param_realloc_model_group[ParamReallocModelPair(src, dst)] = (
                 topology.new_or_get_group(list(sorted(pair_mw_ranks)))
             )
+            param_realloc_model_cpu_group[ParamReallocModelPair(src, dst)] = (
+                topology.new_or_get_group(list(sorted(pair_mw_ranks)), backend="gloo")
+            )
     return ParamReallocInfo(
         param_realloc_groups=param_realloc_groups,
         param_realloc_src_ranks=param_realloc_src_ranks,
         param_realloc_dst_ranks=param_realloc_dst_ranks,
         param_realloc_model_group=param_realloc_model_group,
+        param_realloc_model_cpu_group=param_realloc_model_cpu_group,
     )

View File

@ -175,10 +175,10 @@ class ReaLModel(nn.Module):
        self.contiguous_param = None
        self.hf_model_family = hf_model_family
    def save_to_hf(self, tokenizer, save_dir):
        return getattr(self, f"to_{self.hf_model_family}")(tokenizer, save_dir)
    def load_from_hf(self, load_dir):
        return getattr(self, f"from_{self.hf_model_family}")(load_dir)

View File

@ -24,7 +24,6 @@ import numpy as np
 import pynvml
 import tabulate
 import torch
-import torch.distributed
 import torch.distributed as dist
 import torch.utils.data
@ -761,7 +760,7 @@ class ModelWorker(worker_base.Worker):
             or self.__enable_memory_dump
         ):
             torch.cuda.synchronize()
-            torch.distributed.barrier(group=constants.parallelism_group())
+            dist.barrier(group=constants.cpu_parallelism_group())
         # pfer can be a null context if enable_profiler is False
         pfer = get_pytorch_profiler(
             kernel_only=False, enabled=self.__enable_profiler
@ -780,7 +779,7 @@ class ModelWorker(worker_base.Worker):
             or self.__enable_memory_dump
         ):
             pfer.__exit__(None, None, None)
-            torch.distributed.barrier(group=constants.parallelism_group())
+            dist.barrier(group=constants.cpu_parallelism_group())
             torch.cuda.synchronize()
         tok = time.perf_counter()
         rpc_time = tok - tik
@ -913,7 +912,7 @@ class ModelWorker(worker_base.Worker):
             eval_scores.update(scores)
             res.metadata.pop("scores")
-        dist.barrier(group=constants.parallelism_group())
+        dist.barrier(group=constants.cpu_parallelism_group())
         if len(eval_scores) > 0 and self._dp_rank == 0 and self._is_dp_head:
             with open(
                 eval_scores_path,
@ -949,7 +948,7 @@ class ModelWorker(worker_base.Worker):
         self._clear_memory()
         if constants.use_cuda():
             torch.cuda.synchronize()
-        dist.barrier(group=constants.parallelism_group())
+        dist.barrier(group=constants.cpu_parallelism_group())
         return res
     @cuda_tmark("data_transfer", CUDATimeMarkType.comm)
@ -991,7 +990,7 @@ class ModelWorker(worker_base.Worker):
         with constants.model_scope(from_model_name):
             from_model_ranks = constants.parallelism_group_ranks()
         if not param_realloc_comm.is_trainable(from_model_name):
-            if torch.distributed.get_rank() not in from_model_ranks:
+            if dist.get_rank() not in from_model_ranks:
                 return
             if not isinstance(self.__unwrapped_models[from_model_name], ReaLModel):
                 # We can only release the memory of ReaLModel,
@ -1017,7 +1016,7 @@ class ModelWorker(worker_base.Worker):
                 save_dir=realloc_dir,
             )
             self.__save_model(save_meta)
-            g = self.__param_realloc_info.param_realloc_model_group[
+            g = self.__param_realloc_info.param_realloc_model_cpu_group[
                 param_realloc_comm.ParamReallocModelPair(from_model_name, to_model_name)
             ]
             dist.barrier(group=g)
@ -1029,7 +1028,7 @@ class ModelWorker(worker_base.Worker):
             self.__load_model(load_meta)
         # Remove the reallocated checkpoint.
         with constants.model_scope(to_model_name):
-            dist.barrier(constants.parallelism_group())
+            dist.barrier(constants.cpu_parallelism_group())
             if constants.parallelism_rank() == 0:
                 shutil.rmtree(realloc_dir, ignore_errors=True)
                 os.makedirs(realloc_dir, exist_ok=True)
@ -1097,7 +1096,7 @@ class ModelWorker(worker_base.Worker):
             ).is_symlink():
                 os.unlink(save_root / fn)
         shutil.rmtree(save_dir, ignore_errors=True)
-        dist.barrier(constants.parallelism_group())
+        dist.barrier(constants.cpu_parallelism_group())
         self._interface.save(self._model, save_dir)
         # The `save` method of the interface may be empty.
         # We only save the backend state if the parameters have been indeed saved.
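
A note on ordering in the model-worker changes: the GPU is drained with `torch.cuda.synchronize()` before the CPU barrier at the start of a step, and after it at the end, so the barrier itself never schedules GPU work and the measured interval starts and ends in lockstep on every rank. A sketch of the resulting timing pattern with hypothetical names; only `dist.barrier` and `torch.cuda.synchronize` are taken from the diff:

```python
import time

import torch
import torch.distributed as dist


def timed_step(fn, cpu_group, use_cuda=True):
    # Drain the GPU, align ranks on a host-side barrier, run the step,
    # then barrier and drain again before reading the clock.
    if use_cuda:
        torch.cuda.synchronize()
    dist.barrier(group=cpu_group)  # waits on the CPU; no GPU work issued
    tik = time.perf_counter()
    out = fn()
    dist.barrier(group=cpu_group)
    if use_cuda:
        torch.cuda.synchronize()
    tok = time.perf_counter()
    return out, tok - tik
```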