Pull Request #30: Change the topology order of vLLM for better locality of pipeline parallelism

Merge branch fw/topo of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/30?tab=comment

Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* change non-training topo order
* .
* fix the dataloading bug during recover
* fix typo
博惟 2025-03-12 17:26:50 +08:00
parent 3d8be914af
commit 56e4dd1f5d
13 changed files with 123 additions and 121 deletions
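
For context: a `ProcessTopology` enumerates global ranks over the cartesian product of its axes, first axis outermost and last axis varying fastest. Moving `data` to the outermost position for inference means all pipeline stages of one data-parallel replica occupy contiguous global ranks, which are far more likely to be co-located on a single node. A minimal sketch of the idea (standalone Python, not AReaL code):

```python
# Minimal sketch: rank enumeration under the two axis orders.
from itertools import product

def enumerate_ranks(axes, dims):
    # Ranks follow the cartesian product of coordinates,
    # first axis outermost, last axis varying fastest.
    return {coord: rank
            for rank, coord in enumerate(product(*(range(d) for d in dims)))}

# Old order used for generation: pipe#data#model with pp=2, dp=2, mp=2.
old = enumerate_ranks(["pipe", "data", "model"], [2, 2, 2])
# New inference order: data#pipe#model.
new = enumerate_ranks(["data", "pipe", "model"], [2, 2, 2])

# For replica data=0, model=0: its two pipeline stages sit at
# ranks 0 and 4 under pipe#data#model, but at ranks 0 and 2 under
# data#pipe#model, i.e. adjacent up to the model-parallel group.
print(old[(0, 0, 0)], old[(1, 0, 0)])  # 0 4  (coords are pipe, data, model)
print(new[(0, 0, 0)], new[(0, 1, 0)])  # 0 2  (coords are data, pipe, model)
```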

View File

@@ -120,16 +120,14 @@ class ModelShardID:
:param pp_rank: The pipeline-model parallel rank.
:type pp_rank: int
:param topo: The 3D parallelism topology of this model.
:type topo: PipeModelDataParallelTopology
:type topo: ProcessTopology
"""
model_name: ModelName
dp_rank: int
mp_rank: int
pp_rank: int
topo: topology.PipeModelDataParallelTopology = dataclasses.field(
default_factory=lambda: topology.PipeModelDataParallelTopology(1, 1, 1)
)
topo: topology.ProcessTopology
def __post_init__(self):
assert self.dp_rank >= 0 and self.mp_rank >= 0 and self.pp_rank >= 0
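
Since the 1x1x1 default was removed, a `ModelShardID` must now be constructed with an explicit topology. A hedged sketch, using the import paths shown later in this PR; the `ModelName("actor", 0)` constructor is assumed, for illustration only:

```python
# Hedged sketch: callers now pass the topology explicitly.
from realhf.api.core.config import ModelName          # import path per this PR
from realhf.api.core.system_api import ModelShardID   # import path per this PR
from realhf.base import topology

shard = ModelShardID(
    model_name=ModelName("actor", 0),  # assumed constructor, illustration only
    dp_rank=0,
    mp_rank=0,
    pp_rank=0,
    topo=topology.PipeDataModelParallelTopology(
        num_pp=1, num_mp=1, num_dp=1,
        sequence_parallel=False,
        gradient_checkpointing=False,
        gradient_accumulation_fusion=False,
    ),
)
```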

View File

@@ -118,7 +118,7 @@ class ModelWorker:
torch_cache_mysophobia: bool = False
# model_topos and worker_info will be configured automatically
model_rpcs: List[dfg.MFCDef] = None
model_topos: Dict[ModelName, topology.PipeModelDataParallelTopology] = None
model_topos: Dict[ModelName, topology.ProcessTopology] = None
msid2mwid: Dict[ModelShardID, int] = None
data_transfer_pairs: List[Tuple[str, str]] = None
sync_param_pairs: List[Tuple[str, str]] = None
@@ -203,7 +203,7 @@ class MasterWorker:
# main components
n_model_workers: int
model_rpcs: List[dfg.MFCDef] = None
model_topos: Dict[ModelName, topology.PipeModelDataParallelTopology] = None
model_topos: Dict[ModelName, topology.ProcessTopology] = None
msid2mwid: Dict[ModelShardID | str, int] = None
data_transfer_pairs: List[Tuple[str, str]] = None
sync_param_pairs: List[Tuple[str, str]] = None
@@ -387,7 +387,7 @@ class ExperimentConfig:
def _collect_topos(
self, model_names: List[ModelName]
) -> Dict[ModelName, topology.PipeModelDataParallelTopology]:
) -> Dict[ModelName, topology.ProcessTopology]:
model_topos = {}
model_allocations = {}
for model_name in model_names:

View File

@@ -85,9 +85,6 @@ class OptimizerConfig:
:param warmup_steps_proportion: Proportion of total training steps
allocated for warming up. Should be in the interval [0.0, 1.0].
:type warmup_steps_proportion: float
:param offload: Whether to offload the optimizer to CPU. Only valid
for the DeepSpeed backend.
:type offload: bool
"""
type: str = dataclasses.field(
@@ -216,12 +213,8 @@ class ModelTrainEvalConfig:
:type gradient_checkpointing: bool
:param bf16: Whether to use bf16 precision. Otherwise use fp16.
:type bf16: bool
:param offload: Whether to offload model parameters to CPU. Only valid for the DeepSpeed backend.
:type offload: bool
:param parallel: Configuration for parallelism.
:type parallel: ParallelismConfig
:param zero_stage: Stage of ZeRO optimization. Should be one of 0, 1, 2, or 3.
:type zero_stage: int
:param optimizer: Configuration for the optimizer.
:type optimizer: Optional[OptimizerConfig]
:param init_critic_from_actor: Whether to initialize a critic/reward model from a saved LM checkpoint.
@@ -235,11 +228,6 @@ class ModelTrainEvalConfig:
path: str = ""
gradient_checkpointing: bool = True
bf16: bool = False
offload: bool = False
zero_stage: int = dataclasses.field(
metadata={"choices": [0, 1, 2, 3]},
default=2,
)
optimizer: Optional[OptimizerConfig] = dataclasses.field(
default_factory=OptimizerConfig
)

View File

@@ -24,7 +24,7 @@ logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from realhf.api.core.config import ModelName
from realhf.api.core.system_api import ModelShardID
from realhf.base.topology import ParallelGrid, PipeModelDataParallelTopology
from realhf.base.topology import ParallelGrid, ProcessTopology
class GlobalMemoryBuffer:
@@ -137,7 +137,6 @@ if cluster_spec.name == "wa180":
BASE_ENVIRONS.update(PPU_ENVIRONS)
elif cluster_spec.name == "na132":
# Specific environment variable for h800 cluster na132
# FIXME: change to general cases for open source repo
NV_ENVIRONS = {
"NCCL_SOCKET_IFNAME": "bond0",
"NCCL_NET_PLUGIN": "",
@@ -269,7 +268,7 @@ def set_self_group(pgroup):
def set_rank_mapping(
model_name: "ModelName",
topo: "PipeModelDataParallelTopology",
topo: "ProcessTopology",
msid2mwid: Optional[Dict["ModelShardID", int]] = None,
):
global _rank_mapping
@@ -317,8 +316,8 @@ def gradient_accumulation_fusion() -> bool:
import fused_weight_gradient_mlp_cuda
except ImportError:
_grad_accum_fusion_available = False
return (
_grad_accum_fusion_available and grid().topology().gradient_accumulation_fusion
return _grad_accum_fusion_available and getattr(
grid().topology(), "gradient_accumulation_fusion", False
)
@@ -327,7 +326,7 @@ def max_prompt_len() -> int:
def gradient_checkpointing() -> bool:
return grid().topology().gradient_checkpointing
return getattr(grid().topology(), "gradient_checkpointing", False)
def has_model_name(name: str) -> bool:
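
The `getattr` fallbacks above are needed because the inference topology introduced later in this PR carries no training-only flags, so plain attribute access would raise `AttributeError` on the inference path. A minimal standalone illustration:

```python
# Minimal illustration (not AReaL code) of the getattr hedge.
class TrainTopo:
    gradient_checkpointing = True  # training topologies carry this flag

class InferTopo:
    pass  # inference topologies do not

for topo in (TrainTopo(), InferTopo()):
    # A False default degrades gracefully on inference topologies.
    print(getattr(topo, "gradient_checkpointing", False))
# -> True, then False
```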

View File

@@ -21,7 +21,11 @@ import torch.utils.data
from realhf.api.core.data_api import SequenceSample
from realhf.base import constants, gpu_utils, logging, name_resolve, names, topology
from realhf.base.topology import ParallelGrid, PipeModelDataParallelTopology
from realhf.base.topology import (
DataPipeModelParallelTopology,
ParallelGrid,
PipeDataModelParallelTopology,
)
logger = logging.getLogger("testing")
@@ -212,19 +216,28 @@ def init_global_constants(
gradient_checkpointing=True,
gradient_accumulation_fusion=False,
max_prompt_len=None,
is_train: bool = True,
):
model_name = model_name if model_name is not None else MODEL_NAME
if topo is None:
topo = PipeModelDataParallelTopology(
num_dp=num_dp,
num_mp=num_mp,
num_pp=num_pp,
sequence_parallel=sequence_parallel,
gradient_checkpointing=gradient_checkpointing,
gradient_accumulation_fusion=gradient_accumulation_fusion,
max_prompt_len=max_prompt_len,
)
if is_train:
topo = PipeDataModelParallelTopology(
num_dp=num_dp,
num_mp=num_mp,
num_pp=num_pp,
sequence_parallel=sequence_parallel,
gradient_checkpointing=gradient_checkpointing,
gradient_accumulation_fusion=gradient_accumulation_fusion,
max_prompt_len=max_prompt_len,
)
else:
topo = DataPipeModelParallelTopology(
num_dp=num_dp,
num_mp=num_mp,
num_pp=num_pp,
sequence_parallel=sequence_parallel,
)
ws = num_dp * num_mp * num_pp
else:
ws = topo.world_size()
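
A hedged usage sketch of the new flag (parameter names as in the diff above; other arguments keep their defaults):

```python
# In a test, an inference-only model now gets the data-outermost
# topology automatically instead of the training ordering.
init_global_constants(
    num_dp=2,
    num_mp=1,
    num_pp=2,
    sequence_parallel=False,
    is_train=False,  # selects DataPipeModelParallelTopology
)
```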

View File

@@ -65,15 +65,22 @@ def decompose_to_three_factors(n: int) -> List[Tuple[int, int, int]]:
return factors
class ProcessCoord(NamedTuple):
class PipeDataModelProcessCoord(NamedTuple):
pipe: int
data: int
model: int
class DataPipeModelProcessCoord(NamedTuple):
data: int
pipe: int
model: int
# Explicitly define these classes to allow pickling.
PROCESS_COORD_REGISTRY = {
"pipe#data#model": ProcessCoord,
"pipe#data#model": PipeDataModelProcessCoord,
"data#pipe#model": DataPipeModelProcessCoord,
}
@@ -320,7 +327,7 @@ def _prime_factors(N):
return primes
class PipeModelDataParallelTopology(ProcessTopology):
class PipeDataModelParallelTopology(ProcessTopology):
"""A topology for hybrid pipeline, model, and data parallelism."""
def __init__(
@@ -341,6 +348,25 @@ class PipeModelDataParallelTopology(ProcessTopology):
self.gradient_accumulation_fusion = gradient_accumulation_fusion
class DataPipeModelParallelTopology(ProcessTopology):
"""A topology for hybrid data, pipeline, and tensor parallelism.
Note that DP is the outermost dimension. Used for inference only.
"""
def __init__(
self,
num_pp: int,
num_mp: int,
num_dp: int,
sequence_parallel: bool,
max_prompt_len: Optional[int] = None,
):
super().__init__(axes=["data", "pipe", "model"], dims=[num_dp, num_pp, num_mp])
self.sequence_parallel = sequence_parallel
self.max_prompt_len = max_prompt_len
class ParallelGrid:
"""Implements a grid object that stores the data parallel ranks
corresponding to each of the model parallel stages.
@@ -365,7 +391,7 @@ class ParallelGrid:
def __init__(
self,
topology: PipeModelDataParallelTopology,
topology: ProcessTopology,
process_group: dist.ProcessGroup,
rank_mapping: Optional[Dict[int, int]] = None,
):
@@ -558,7 +584,7 @@ class ParallelGrid:
transform = me._replace(pipe=stage_id, **kwargs)._asdict()
return self._topo.get_rank(**transform)
def topology(self) -> PipeModelDataParallelTopology:
def topology(self) -> ProcessTopology:
return self._topo
# MPU functions for DeepSpeed integration
@@ -630,7 +656,7 @@ class ParallelGrid:
class FakeGrid:
"""Used for testing dynamic scheduling in none-GPU environment."""
def __init__(self, rank: int, topo: PipeModelDataParallelTopology):
def __init__(self, rank: int, topo: ProcessTopology):
self.rank = rank
self._topo = topo
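
Assuming DeepSpeed-style `ProcessTopology.get_rank` semantics (first axis outermost), the two topology classes place the same logical shard at different global ranks. A hedged sketch:

```python
from realhf.base import topology

# Training order: pipe is the outermost axis.
train = topology.PipeDataModelParallelTopology(
    num_pp=2, num_mp=1, num_dp=2,
    sequence_parallel=False,
    gradient_checkpointing=False,
    gradient_accumulation_fusion=False,
)
# Inference order: data is the outermost axis, so one replica's
# pipeline stages occupy contiguous ranks.
infer = topology.DataPipeModelParallelTopology(
    num_pp=2, num_mp=1, num_dp=2, sequence_parallel=False,
)

assert train.get_rank(pipe=1, data=0, model=0) == 2  # PP stages 2 ranks apart
assert infer.get_rank(pipe=1, data=0, model=0) == 1  # PP stages adjacent
```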

View File

@@ -25,10 +25,7 @@ def check_is_realhf_native_model_interface(name):
def check_valid_vllm(role: str, vllm: vLLMConfig, rpc_allocs: List[RPCAllocation]):
rpcs = [alloc.rpc for alloc in rpc_allocs if alloc.rpc.role == role]
if vllm.hybrid_train and not any(rpc.is_train() for rpc in rpcs):
logger.warning(
"vLLM hybrid_train is enabled, but no training RPCs are found. Set it to False."
)
vllm.hybrid_train = False
logger.warning("vLLM hybrid_train is enabled, but no training RPCs are found.")
if vllm.hybrid_train and not vllm.enforce_eager:
raise ValueError("vLLM hybrid_train requires eager mode to be enabled.")
@@ -45,18 +42,6 @@ def check_valid_optimizer(model: ModelTrainEvalConfig):
)
def check_valid_backend(role: str, model: ModelTrainEvalConfig):
if (model.offload or model.optimizer.offload) and model.backend != "deepspeed":
raise ValueError(
f"For model `{role}`, offload is only" " valid for the deepspeed backend."
)
if model.backend == "megatron" and model.zero_stage in [3]:
raise ValueError(
f"For model `{role}`, the Megatron backend"
" only supports zero stage 0, 1 or 2."
)
def check_valid_model_and_path(role: str, model: ModelTrainEvalConfig):
if not os.path.exists(model.path):
raise FileNotFoundError(

View File

@@ -48,7 +48,6 @@ from realhf.api.quickstart.model import (
from realhf.base.cluster import spec as cluster_spec
from realhf.experiments.common.check import (
check_is_realhf_native_model_interface,
check_valid_backend,
check_valid_model_and_path,
check_valid_optimizer,
check_valid_parallel_batch_size,
@@ -572,6 +571,7 @@ class CommonExperimentConfig(Experiment):
gradient_checkpointing=False,
max_prompt_len=(self.max_prompt_len),
gradient_accumulation_fusion=False,
is_train=False,
)
model_cfg = self.models[model_name.role]
global vLLM_HYBRID_TRAIN_DECOUPLE_ALLOC_WARN
@@ -675,6 +675,7 @@ class CommonExperimentConfig(Experiment):
),
gradient_accumulation_fusion=(model_cfg.backend == "megatron")
and (model_cfg.type._class != "bailing"),
is_train=any(rpc.is_train() for rpc in rpcs),
)
if any(rpc.is_train() for rpc in rpcs):
@@ -794,6 +795,5 @@ class CommonExperimentConfig(Experiment):
for alloc in rpc_allocs:
check_valid_parallel_batch_size(alloc)
for role, model in self.models.items():
check_valid_backend(role, model)
check_valid_model_and_path(role, model)
check_valid_optimizer(model)

View File

@@ -25,7 +25,11 @@ from realhf.api.quickstart.model import (
parallelism_eq,
)
from realhf.base import logging
from realhf.base.topology import PipeModelDataParallelTopology
from realhf.base.topology import (
DataPipeModelParallelTopology,
PipeDataModelParallelTopology,
ProcessTopology,
)
logger = logging.getLogger("Experiment Common Utils", "benchmark")
@@ -34,16 +38,25 @@ def get_topo(
parallel: ParallelismConfig,
gradient_checkpointing: bool,
gradient_accumulation_fusion: bool,
is_train: bool,
max_prompt_len: Optional[int] = None,
) -> PipeModelDataParallelTopology:
return PipeModelDataParallelTopology(
) -> ProcessTopology:
if is_train:
return PipeDataModelParallelTopology(
num_mp=parallel.model_parallel_size,
num_pp=parallel.pipeline_parallel_size,
num_dp=parallel.data_parallel_size,
sequence_parallel=parallel.use_sequence_parallel,
gradient_checkpointing=gradient_checkpointing,
max_prompt_len=max_prompt_len,
gradient_accumulation_fusion=gradient_accumulation_fusion,
)
return DataPipeModelParallelTopology(
num_mp=parallel.model_parallel_size,
num_pp=parallel.pipeline_parallel_size,
num_dp=parallel.data_parallel_size,
sequence_parallel=parallel.use_sequence_parallel,
max_prompt_len=max_prompt_len,
)
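
A hedged call sketch: generation-only RPCs now pass `is_train=False` (as wired up in `common.py` above), which yields the inference ordering:

```python
# parallel is a ParallelismConfig; argument names as in the diff above.
topo = get_topo(
    parallel,
    gradient_checkpointing=False,
    gradient_accumulation_fusion=False,
    is_train=False,       # -> DataPipeModelParallelTopology
    max_prompt_len=4096,  # illustrative value
)
```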
@@ -58,49 +71,12 @@ def get_world_size(parallel: ParallelismConfig) -> int:
def make_train_backend_config(
model_cfg: ModelTrainEvalConfig, parallel_cfg: ParallelismConfig
):
if model_cfg.backend == "deepspeed":
return ModelBackendAbstraction(
"deepspeed",
args=dict(
optimizer_name="adam",
optimizer_config=dict(
lr=model_cfg.optimizer.lr,
weight_decay=model_cfg.optimizer.weight_decay,
eps=model_cfg.optimizer.eps,
betas=(
model_cfg.optimizer.beta1,
model_cfg.optimizer.beta2,
),
),
lr_scheduler_type=model_cfg.optimizer.lr_scheduler_type,
warmup_steps_proportion=model_cfg.optimizer.warmup_steps_proportion,
min_lr_ratio=model_cfg.optimizer.min_lr_ratio,
zero_stage=(
model_cfg.zero_stage
if parallel_cfg.pipeline_parallel_size == 1
else min(model_cfg.zero_stage, 1)
),
offload_optimizer_state=model_cfg.optimizer.offload,
offload_param=model_cfg.offload,
bf16=model_cfg.bf16,
),
)
elif model_cfg.backend == "megatron":
if model_cfg.optimizer.offload or model_cfg.offload:
raise ValueError("Offload is not supported in Megatron backend.")
if model_cfg.zero_stage == 3:
raise ValueError("Zero stage 3 is not supported in Megatron backend.")
if model_cfg.zero_stage == 2:
logger.warning(
"Megatron does not support ZeRO stage 2. Degenerates to stage 1."
)
model_cfg.zero_stage = 1
if model_cfg.backend == "megatron":
megatron_args: Dict[str, Any] = OmegaConf.to_container(model_cfg.megatron)
return ModelBackendAbstraction(
"megatron",
args=dict(
bf16=model_cfg.bf16,
zero_stage=model_cfg.zero_stage,
optimizer=model_cfg.optimizer,
**megatron_args,
),
@@ -225,10 +201,12 @@ def resolve_rpc_hooks(
class AllocationType(enum.Enum):
DECOUPLED = 1
DECOUPLED_vLLM = 1
GLOBAL_HYBRID = 2
MANUAL = 3
SEARCH = 4
HEURISTIC = 4
SEARCH = 5
DECOUPLED_SGLANG = 6
@dataclasses.dataclass
@@ -237,7 +215,16 @@ class AllocationMode:
parallel_strat: Dict[str, Dict[str, int]]
def is_decoupled(self):
return self.type_ == AllocationType.DECOUPLED
return self.type_ in [
AllocationType.DECOUPLED_vLLM,
AllocationType.DECOUPLED_SGLANG,
]
def is_decoupled_vllm(self):
return self.type_ == AllocationType.DECOUPLED_vLLM
def is_decoupled_sglang(self):
return self.type_ == AllocationType.DECOUPLED_SGLANG
def is_global_hybrid(self):
return self.type_ == AllocationType.GLOBAL_HYBRID
@@ -246,13 +233,19 @@
def from_str(cls, allocation_mode: str):
if allocation_mode == "manual":
return cls(AllocationType.MANUAL, None)
if allocation_mode == "heuristic":
return cls(AllocationType.HEURISTIC, None)
if allocation_mode == "search":
return cls(AllocationType.SEARCH, None)
alloc_3d = AllocationMode.extract_3d_alloc(allocation_mode)
alloc_hybrid = AllocationMode.extract_key_value_alloc(allocation_mode)
alloc_decoupled = AllocationMode.extract_decoupled_alloc(allocation_mode)
if alloc_decoupled:
return cls(AllocationType.DECOUPLED, alloc_decoupled)
if "vllm" in allocation_mode:
return cls(AllocationType.DECOUPLED_vLLM, alloc_decoupled)
elif "sglang" in allocation_mode:
return cls(AllocationType.DECOUPLED_SGLANG, alloc_decoupled)
if alloc_3d:
return cls(AllocationType.GLOBAL_HYBRID, alloc_3d)
if alloc_hybrid:
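
A hedged example of the new dispatch; the decoupled allocation-mode string format (e.g. `vllm.d4p1m1+d2p2m1`) is assumed from AReaL's quickstart conventions:

```python
# The inference backend named in a decoupled allocation string now
# selects the corresponding enum variant.
am = AllocationMode.from_str("vllm.d4p1m1+d2p2m1")    # assumed format
assert am.is_decoupled() and am.is_decoupled_vllm()

am = AllocationMode.from_str("sglang.d4p1m1+d2p2m1")  # assumed format
assert am.is_decoupled() and am.is_decoupled_sglang()
```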

View File

@@ -26,7 +26,7 @@ class NCCLProcessGroupInfo:
def filter_match_mwids(
model_name: ModelName,
topo: topology.PipeModelDataParallelTopology,
topo: topology.ProcessTopology,
msid2mwid: Dict[ModelShardID, int],
**conditions,
) -> List[int]:
@@ -49,7 +49,7 @@ def setup_global_comm(
expr_name: str,
trial_name: str,
worker_index: int,
model_topos: Optional[Dict[str, topology.PipeModelDataParallelTopology]] = None,
model_topos: Optional[Dict[str, topology.ProcessTopology]] = None,
msid2mwid: Optional[Dict[ModelShardID, int]] = None,
backend: str = "nccl",
) -> NCCLProcessGroupInfo:

View File

@@ -152,8 +152,8 @@ def _assign_src_to_dsts(
def _create_param_realloc_groups(
from_topo: topology.PipeModelDataParallelTopology,
to_topo: topology.PipeModelDataParallelTopology,
from_topo: topology.ProcessTopology,
to_topo: topology.ProcessTopology,
src: ModelName,
dst: ModelName,
msid2mwid: Dict[ModelShardID, int],
@@ -262,7 +262,7 @@ def _create_param_realloc_groups(
def setup_param_realloc(
model_topos: Optional[Dict[str, topology.PipeModelDataParallelTopology]] = None,
model_topos: Optional[Dict[str, topology.ProcessTopology]] = None,
msid2mwid: Optional[Dict[ModelShardID, int]] = None,
param_realloc_pairs: Optional[List[Tuple[ModelName, ModelName]]] = None,
) -> ParamReallocInfo:
@@ -341,8 +341,8 @@ class ReparallelizeReceiverStep:
def _derive_reparallelize_comm_plan(
from_model_name: ModelName,
to_model_name: ModelName,
from_topo: topology.PipeModelDataParallelTopology,
to_topo: topology.PipeModelDataParallelTopology,
from_topo: topology.ProcessTopology,
to_topo: topology.ProcessTopology,
from_model_config: model_api.ReaLModelConfig,
to_model_config: model_api.ReaLModelConfig,
pg_info: ParamReallocInfo,

View File

@@ -562,8 +562,8 @@ class ReaLModel(nn.Module):
self,
from_model_name: ModelName,
to_model_name: ModelName,
from_topo: topology.PipeModelDataParallelTopology,
to_topo: topology.PipeModelDataParallelTopology,
from_topo: topology.ProcessTopology,
to_topo: topology.ProcessTopology,
to_model_config: model_api.ReaLModelConfig,
pg_info: NCCLProcessGroupInfo,
from_model_config: None | model_api.ReaLModelConfig = None,
@@ -638,8 +638,8 @@ class ReaLModel(nn.Module):
self,
from_model_name: ModelName,
to_model_name: ModelName,
from_topo: topology.PipeModelDataParallelTopology,
to_topo: topology.PipeModelDataParallelTopology,
from_topo: topology.ProcessTopology,
to_topo: topology.ProcessTopology,
to_model_config: model_api.ReaLModelConfig,
pg_info: NCCLProcessGroupInfo,
) -> Tuple[nn.ModuleList, torch.Tensor, torch.Tensor]:

View File

@@ -39,8 +39,8 @@ def compute_cost(
world_size: int,
from_model_name: ModelName,
to_model_name: ModelName,
from_topo: topology.PipeModelDataParallelTopology,
to_topo: topology.PipeModelDataParallelTopology,
from_topo: topology.ProcessTopology,
to_topo: topology.ProcessTopology,
model_config: ReaLModelConfig,
bw: float, # Gbps
set_interval_cost: float,
@@ -161,10 +161,10 @@ def dump_table(
):
world_size = max(a, b)
from_topo = topology.PipeModelDataParallelTopology(
from_topo = topology.PipeDataModelParallelTopology(
*from_pp_mp_dp, False, False
)
to_topo = topology.PipeModelDataParallelTopology(*to_pp_mp_dp, False, False)
to_topo = topology.PipeDataModelParallelTopology(*to_pp_mp_dp, False, False)
assert world_size >= from_topo.world_size()
assert world_size >= to_topo.world_size()