Mirror of https://github.com/inclusionAI/AReaL

Pull Request #46: Call all runtime barriers on CPU process groups and fix SGLang performance with TP > 1

Merge branch fw/fix-gpu-barrier of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/46?tab=diff
Signed-off-by: 闻通 <albert.zty@antgroup.com>

Parent: 9a9d86112e
Commit: 9c55827b32
```diff
@@ -179,6 +179,9 @@ _grids: Dict["ModelName", "ParallelGrid"] = {}
 _pgroups: Dict["ModelName", Any] = (
     {}
 )  # torch.distributed.ProcessGroup, not type hint here to avoid importing torch
+_cpu_pgroups: Dict["ModelName", Any] = (
+    {}
+)  # torch.distributed.ProcessGroup, not type hint here to avoid importing torch
 _pgroup_ranks: Dict["ModelName", List[int]] = {}
 _self_group = None
 _rank_mapping: Dict["ModelName", Dict["ModelShardID", int]] = {}
@@ -261,6 +264,13 @@ def set_parallelism_group(model_name: "ModelName", pgroup, ranks):
     _pgroup_ranks[model_name] = ranks


+def set_cpu_parallelism_group(model_name: "ModelName", pgroup):
+    global _cpu_pgroups
+    if model_name in _cpu_pgroups:
+        raise RuntimeError(f"Parallelism group for model {model_name} is already set.")
+    _cpu_pgroups[model_name] = pgroup
+
+
 def set_self_group(pgroup):
     global _self_group
     if _self_group is not None:
@@ -384,6 +394,15 @@ def parallelism_group():
     return _pgroups[_model_name]


+def cpu_parallelism_group():
+    """Returns the GLOO 3D parallelism group of a specific model."""
+    if _model_name is None:
+        raise RuntimeError("Global constant `model_name` is accessed before set.")
+    if _cpu_pgroups.get(_model_name, None) is None:
+        raise RuntimeError(f"Parallelism group for model {_model_name} is not set.")
+    return _cpu_pgroups[_model_name]
+
+
 def parallelism_group_ranks():
     if _model_name is None:
         raise RuntimeError("Global constant `model_name` is accessed before set.")
```
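These hunks add a CPU-side registry (`_cpu_pgroups`) alongside the existing NCCL group registry, together with a `set_cpu_parallelism_group` / `cpu_parallelism_group` pair that mirrors the GPU accessors. Below is a minimal standalone sketch, not AReaL code, of the underlying torch.distributed pattern: a GLOO group created over the same ranks as the default group, so a barrier can complete entirely on CPU. The registry key `"default"` and the port number are made up.

```python
# Minimal standalone sketch (not AReaL code): create a GLOO process group over
# the same ranks as the default group and run a CPU-only barrier on it.
# Uses the gloo backend throughout, so it also runs on a machine without GPUs.
import os

import torch.distributed as dist
import torch.multiprocessing as mp

_cpu_pgroups = {}  # mirrors the module-level registry in the diff, keyed by model name


def _worker(rank: int, world_size: int):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29531"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Register a dedicated CPU (gloo) group for a hypothetical model name.
    _cpu_pgroups["default"] = dist.new_group(ranks=list(range(world_size)), backend="gloo")

    # The barrier completes on CPU; no CUDA context or GPU kernel is involved.
    dist.barrier(group=_cpu_pgroups["default"])
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)
```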
```diff
@@ -182,7 +182,7 @@ class SGLangGenerationEngine(PipelinableEngine):
         request_timeout: int = 1800,
     ):
         if constants.model_parallel_rank() != 0:
-            dist.barrier(group=constants.model_parallel_group())
+            dist.barrier(group=constants.model_parallel_cpu_group())
             return
         # Start the serving process
         self.server_proc = mp.Process(
@@ -209,7 +209,7 @@ class SGLangGenerationEngine(PipelinableEngine):
         # offload weights/cache
         self.hybrid_train = hybrid_train

-        dist.barrier(group=constants.model_parallel_group())
+        dist.barrier(group=constants.model_parallel_cpu_group())

     def __del__(self):
         if hasattr(self, "server_proc"):
@@ -352,7 +352,7 @@ class SGLangGenerationEngine(PipelinableEngine):
                 "because we force to skip_tokenizer_init."
             )
         if constants.model_parallel_rank() != 0:
-            dist.barrier(group=constants.model_parallel_group())
+            dist.barrier(group=constants.model_parallel_cpu_group())
             return None, None, None

         results = asyncio.run(
@@ -363,12 +363,12 @@ class SGLangGenerationEngine(PipelinableEngine):
                 gconfig=gconfig,
             )
         )
-        dist.barrier(group=constants.model_parallel_group())
+        dist.barrier(group=constants.model_parallel_cpu_group())
         return results

     def update_weights_from_disk(self, path):
         if constants.model_parallel_rank() != 0:
-            dist.barrier(group=constants.model_parallel_group())
+            dist.barrier(group=constants.model_parallel_cpu_group())
             return

         async def _fn():
@@ -379,7 +379,7 @@ class SGLangGenerationEngine(PipelinableEngine):
             await client.async_update_weights_from_disk(path)

         asyncio.run(_fn())
-        dist.barrier(group=constants.model_parallel_group())
+        dist.barrier(group=constants.model_parallel_cpu_group())


 @dataclasses.dataclass
```
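In `SGLangGenerationEngine`, only model-parallel rank 0 talks to the SGLang server; the other tensor-parallel ranks just wait at a barrier, which this PR moves from `model_parallel_group()` to `model_parallel_cpu_group()`. A plausible reading of the PR title is that waiting inside an NCCL barrier keeps GPU work resident and competes with SGLang's inference kernels when TP > 1, while a GLOO barrier blocks only the CPU thread. Below is a standalone sketch of the rank-0-drives-the-work pattern; all names are illustrative, and in AReaL the groups come from `constants`.

```python
# Standalone sketch: rank 0 performs the work, every rank syncs on a CPU-only
# barrier. The "generation result" placeholder stands in for the real server call.
import os

import torch.distributed as dist
import torch.multiprocessing as mp


def generate(cpu_group: dist.ProcessGroup, rank: int):
    if rank != 0:
        # Non-zero ranks do not talk to the server; they wait on a CPU-only
        # barrier so no GPU kernel is held while waiting.
        dist.barrier(group=cpu_group)
        return None
    result = "rank-0 generation result"  # placeholder for the real server request
    dist.barrier(group=cpu_group)
    return result


def _worker(rank: int, world_size: int):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29532"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    cpu_group = dist.new_group(ranks=list(range(world_size)), backend="gloo")
    print(rank, generate(cpu_group, rank))
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)
```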
```diff
@@ -136,6 +136,8 @@ def setup_global_comm(
     for model_name, ranks in mw_ranks.items():
         model_groups[model_name] = topology.new_or_get_group(ranks, backend=backend)
         constants.set_parallelism_group(model_name, model_groups[model_name], ranks)
+        cpu_group = topology.new_or_get_group(ranks, backend="gloo")
+        constants.set_cpu_parallelism_group(model_name, cpu_group)

     self_group = None
     for i in range(world_size):
```
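`setup_global_comm` now builds a second group over the same worker ranks with `backend="gloo"` and registers it through the new setter. `topology.new_or_get_group` is AReaL's cached group constructor; a hypothetical stand-in (not the real implementation) could look like the sketch below. It relies on the fact that torch.distributed allows several groups with different backends over the same set of ranks, provided every process creates them in the same order.

```python
# Hypothetical stand-in for a `new_or_get_group`-style helper: cache process
# groups by (ranks, backend) so repeated calls return the same group object.
from typing import Dict, List, Optional, Tuple

import torch.distributed as dist

_group_cache: Dict[Tuple[Tuple[int, ...], Optional[str]], dist.ProcessGroup] = {}


def new_or_get_group(ranks: List[int], backend: Optional[str] = None) -> dist.ProcessGroup:
    key = (tuple(sorted(ranks)), backend)
    if key not in _group_cache:
        # Every rank must execute new_group() calls in the same order; caching
        # by (ranks, backend) keeps that order deterministic across processes.
        _group_cache[key] = dist.new_group(ranks=sorted(ranks), backend=backend)
    return _group_cache[key]
```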
```diff
@@ -62,6 +62,9 @@ class ParamReallocInfo:
     param_realloc_model_group: Dict[
         ParamReallocModelPair, torch.distributed.ProcessGroup
     ]
+    param_realloc_model_cpu_group: Dict[
+        ParamReallocModelPair, torch.distributed.ProcessGroup
+    ]
     param_realloc_groups: Dict[ParamReallocPair, torch.distributed.ProcessGroup]
     param_realloc_src_ranks: Dict[ParamReallocPair, int]
     param_realloc_dst_ranks: Dict[ParamReallocPair, List[int]]
@@ -270,6 +273,7 @@ def setup_param_realloc(
     param_realloc_src_ranks = {}
     param_realloc_dst_ranks = {}
     param_realloc_model_group = {}
+    param_realloc_model_cpu_group = {}
     if param_realloc_pairs is not None:
         for src, dst in param_realloc_pairs:
             _create_param_realloc_groups(
@@ -296,11 +300,15 @@ def setup_param_realloc(
         param_realloc_model_group[ParamReallocModelPair(src, dst)] = (
             topology.new_or_get_group(list(sorted(pair_mw_ranks)))
         )
+        param_realloc_model_cpu_group[ParamReallocModelPair(src, dst)] = (
+            topology.new_or_get_group(list(sorted(pair_mw_ranks)), backend="gloo")
+        )
     return ParamReallocInfo(
         param_realloc_groups=param_realloc_groups,
         param_realloc_src_ranks=param_realloc_src_ranks,
         param_realloc_dst_ranks=param_realloc_dst_ranks,
         param_realloc_model_group=param_realloc_model_group,
+        param_realloc_model_cpu_group=param_realloc_model_cpu_group,
     )

```
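`ParamReallocInfo` gains a GLOO group dictionary keyed by the same `ParamReallocModelPair` as the NCCL one, so the barrier after saving a reallocated checkpoint (see the `ModelWorker` hunk below) can run on CPU. The sketch below illustrates keeping the two group maps side by side; the class and method names here are made up.

```python
# Illustrative sketch: maintain a default-backend group (NCCL on GPU setups)
# and a gloo group per (src, dst) model pair, and barrier on the CPU one.
import dataclasses
from typing import Dict, List, NamedTuple

import torch.distributed as dist


class ModelPair(NamedTuple):
    src: str
    dst: str


@dataclasses.dataclass
class ReallocGroups:
    gpu_group: Dict[ModelPair, dist.ProcessGroup] = dataclasses.field(default_factory=dict)
    cpu_group: Dict[ModelPair, dist.ProcessGroup] = dataclasses.field(default_factory=dict)

    def register(self, pair: ModelPair, ranks: List[int]):
        # Default-backend group for tensor traffic, gloo group for CPU-side barriers.
        self.gpu_group[pair] = dist.new_group(ranks=sorted(ranks))
        self.cpu_group[pair] = dist.new_group(ranks=sorted(ranks), backend="gloo")

    def barrier_cpu(self, pair: ModelPair):
        dist.barrier(group=self.cpu_group[pair])
```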
```diff
@@ -175,10 +175,10 @@ class ReaLModel(nn.Module):
         self.contiguous_param = None

         self.hf_model_family = hf_model_family

     def save_to_hf(self, tokenizer, save_dir):
         return getattr(self, f"to_{self.hf_model_family}")(tokenizer, save_dir)

     def load_from_hf(self, load_dir):
         return getattr(self, f"from_{self.hf_model_family}")(load_dir)

```
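The `ReaLModel` hunk shows the family-dispatch helpers `save_to_hf` and `load_from_hf`, which route to per-family converters via `getattr`. A minimal sketch of that dispatch pattern, with a hypothetical `to_llama` / `from_llama` pair:

```python
# Minimal sketch of getattr-based family dispatch. The to_llama / from_llama
# converters are hypothetical placeholders; the real model registers one pair
# of converters per supported HF model family.
class FamilyDispatcher:
    def __init__(self, hf_model_family: str):
        self.hf_model_family = hf_model_family

    def to_llama(self, tokenizer, save_dir):
        print(f"saving llama-format checkpoint to {save_dir}")

    def from_llama(self, load_dir):
        print(f"loading llama-format checkpoint from {load_dir}")

    def save_to_hf(self, tokenizer, save_dir):
        # Dispatch to the converter named after the model family, e.g. "to_llama".
        return getattr(self, f"to_{self.hf_model_family}")(tokenizer, save_dir)

    def load_from_hf(self, load_dir):
        return getattr(self, f"from_{self.hf_model_family}")(load_dir)
```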
```diff
@@ -24,7 +24,6 @@ import numpy as np
 import pynvml
 import tabulate
 import torch
-import torch.distributed
 import torch.distributed as dist
 import torch.utils.data

@@ -761,7 +760,7 @@ class ModelWorker(worker_base.Worker):
             or self.__enable_memory_dump
         ):
             torch.cuda.synchronize()
-            torch.distributed.barrier(group=constants.parallelism_group())
+            dist.barrier(group=constants.cpu_parallelism_group())
         # pfer can be a null context if enable_profiler is False
         pfer = get_pytorch_profiler(
             kernel_only=False, enabled=self.__enable_profiler
@@ -780,7 +779,7 @@ class ModelWorker(worker_base.Worker):
             or self.__enable_memory_dump
         ):
             pfer.__exit__(None, None, None)
-            torch.distributed.barrier(group=constants.parallelism_group())
+            dist.barrier(group=constants.cpu_parallelism_group())
             torch.cuda.synchronize()
         tok = time.perf_counter()
         rpc_time = tok - tik
@@ -913,7 +912,7 @@ class ModelWorker(worker_base.Worker):
                 eval_scores.update(scores)

         res.metadata.pop("scores")
-        dist.barrier(group=constants.parallelism_group())
+        dist.barrier(group=constants.cpu_parallelism_group())
         if len(eval_scores) > 0 and self._dp_rank == 0 and self._is_dp_head:
             with open(
                 eval_scores_path,
@@ -949,7 +948,7 @@ class ModelWorker(worker_base.Worker):
         self._clear_memory()
         if constants.use_cuda():
             torch.cuda.synchronize()
-        dist.barrier(group=constants.parallelism_group())
+        dist.barrier(group=constants.cpu_parallelism_group())
         return res

     @cuda_tmark("data_transfer", CUDATimeMarkType.comm)
@@ -991,7 +990,7 @@ class ModelWorker(worker_base.Worker):
         with constants.model_scope(from_model_name):
             from_model_ranks = constants.parallelism_group_ranks()
         if not param_realloc_comm.is_trainable(from_model_name):
-            if torch.distributed.get_rank() not in from_model_ranks:
+            if dist.get_rank() not in from_model_ranks:
                 return
             if not isinstance(self.__unwrapped_models[from_model_name], ReaLModel):
                 # We can only release the memory of ReaLModel,
@@ -1017,7 +1016,7 @@ class ModelWorker(worker_base.Worker):
                 save_dir=realloc_dir,
             )
             self.__save_model(save_meta)
-            g = self.__param_realloc_info.param_realloc_model_group[
+            g = self.__param_realloc_info.param_realloc_model_cpu_group[
                 param_realloc_comm.ParamReallocModelPair(from_model_name, to_model_name)
             ]
             dist.barrier(group=g)
@@ -1029,7 +1028,7 @@ class ModelWorker(worker_base.Worker):
             self.__load_model(load_meta)
         # Remove the reallocated checkpoint.
         with constants.model_scope(to_model_name):
-            dist.barrier(constants.parallelism_group())
+            dist.barrier(constants.cpu_parallelism_group())
             if constants.parallelism_rank() == 0:
                 shutil.rmtree(realloc_dir, ignore_errors=True)
                 os.makedirs(realloc_dir, exist_ok=True)
@@ -1097,7 +1096,7 @@ class ModelWorker(worker_base.Worker):
                 ).is_symlink():
                     os.unlink(save_root / fn)
                 shutil.rmtree(save_dir, ignore_errors=True)
-        dist.barrier(constants.parallelism_group())
+        dist.barrier(constants.cpu_parallelism_group())
         self._interface.save(self._model, save_dir)
         # The `save` method of the interface may be empty.
         # We only save the backend state if the parameters have been indeed saved.
```
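In `ModelWorker`, the barriers that bracket the timed and profiled RPC section, as well as those before eval-score dumping, reallocation cleanup, and saving, now run on the CPU group, while `torch.cuda.synchronize()` still drains GPU work before timestamps are taken. Below is a standalone sketch of this synchronize-then-CPU-barrier timing pattern; `run_rpc` and the group wiring are illustrative, not AReaL's real code.

```python
# Standalone sketch of the timing pattern: drain locally queued CUDA kernels,
# line all ranks up on a CPU-only barrier, then take wall-clock timestamps.
import time

import torch
import torch.distributed as dist


def timed_rpc(run_rpc, cpu_group: dist.ProcessGroup):
    """Run `run_rpc()` and return (result, wall-clock seconds), aligning ranks on CPU."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()       # finish queued GPU work locally
    dist.barrier(group=cpu_group)      # align ranks without a GPU-side barrier
    tik = time.perf_counter()

    result = run_rpc()

    dist.barrier(group=cpu_group)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    tok = time.perf_counter()
    return result, tok - tik
```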