mirror of https://github.com/inclusionAI/AReaL
PullRequest: 30 Change the topology order of vLLM for better locality of pipeline parallelism
Merge branch fw/topo of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/30?tab=comment

Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com>

* change non-training topo order
* fix the dataloading bug during recover
* fix typo
This commit is contained in:
parent 3d8be914af
commit 56e4dd1f5d
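The diff below renames the training topology to PipeDataModelParallelTopology and introduces DataPipeModelParallelTopology, whose outermost axis is data parallelism, for non-training (vLLM generation) models. With data as the outermost axis, all pipeline stages of one data-parallel replica land on adjacent global ranks, which appears to be the locality the PR title refers to. The snippet below is a minimal illustrative sketch, not code from this repository; it assumes row-major rank assignment over the declared axes, as in DeepSpeed-style ProcessTopology classes.

from itertools import product

def stage_ranks_per_replica(axes, dims):
    # Assign ranks in row-major order over the axes (last axis varies fastest),
    # then list, for each data-parallel replica, the ranks of its pipeline stages.
    rank_of = {c: r for r, c in enumerate(product(*(range(d) for d in dims)))}
    d_ax, p_ax = axes.index("data"), axes.index("pipe")
    layout = {}
    for d in range(dims[d_ax]):
        ranks = []
        for p in range(dims[p_ax]):
            coord = [0] * len(axes)
            coord[d_ax], coord[p_ax] = d, p
            ranks.append(rank_of[tuple(coord)])
        layout[d] = ranks
    return layout

# Training order ("pipe" outermost): stages of one replica are strided apart.
print(stage_ranks_per_replica(["pipe", "data", "model"], [2, 2, 1]))  # {0: [0, 2], 1: [1, 3]}
# Inference order added here ("data" outermost): stages of one replica are adjacent.
print(stage_ranks_per_replica(["data", "pipe", "model"], [2, 2, 1]))  # {0: [0, 1], 1: [2, 3]}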
@@ -120,16 +120,14 @@ class ModelShardID:
    :param pp_rank: The pipeline-model parallel rank.
    :type pp_rank: int
    :param topo: The 3D parallelism topology of this model.
-   :type topo: PipeModelDataParallelTopology
+   :type topo: ProcessTopology
    """

    model_name: ModelName
    dp_rank: int
    mp_rank: int
    pp_rank: int
-   topo: topology.PipeModelDataParallelTopology = dataclasses.field(
-       default_factory=lambda: topology.PipeModelDataParallelTopology(1, 1, 1)
-   )
+   topo: topology.ProcessTopology

    def __post_init__(self):
        assert self.dp_rank >= 0 and self.mp_rank >= 0 and self.pp_rank >= 0

@@ -118,7 +118,7 @@ class ModelWorker:
    torch_cache_mysophobia: bool = False
    # model_topos and worker_info will be configured automatically
    model_rpcs: List[dfg.MFCDef] = None
-   model_topos: Dict[ModelName, topology.PipeModelDataParallelTopology] = None
+   model_topos: Dict[ModelName, topology.ProcessTopology] = None
    msid2mwid: Dict[ModelShardID, int] = None
    data_transfer_pairs: List[Tuple[str, str]] = None
    sync_param_pairs: List[Tuple[str, str]] = None

@@ -203,7 +203,7 @@ class MasterWorker:
    # main components
    n_model_workers: int
    model_rpcs: List[dfg.MFCDef] = None
-   model_topos: Dict[ModelName, topology.PipeModelDataParallelTopology] = None
+   model_topos: Dict[ModelName, topology.ProcessTopology] = None
    msid2mwid: Dict[ModelShardID | str, int] = None
    data_transfer_pairs: List[Tuple[str, str]] = None
    sync_param_pairs: List[Tuple[str, str]] = None

@@ -387,7 +387,7 @@ class ExperimentConfig:

    def _collect_topos(
        self, model_names: List[ModelName]
-   ) -> Dict[ModelName, topology.PipeModelDataParallelTopology]:
+   ) -> Dict[ModelName, topology.ProcessTopology]:
        model_topos = {}
        model_allocations = {}
        for model_name in model_names:

@@ -85,9 +85,6 @@ class OptimizerConfig:
    :param warmup_steps_proportion: Proportion of total training steps
        allocated for warming up. Should be in the interval [0.0, 1.0].
    :type warmup_steps_proportion: float
-   :param offload: Whether to offload the optimizer to CPU. Only valid
-       for the DeepSpeed backend.
-   :type offload: bool
    """

    type: str = dataclasses.field(

@@ -216,12 +213,8 @@ class ModelTrainEvalConfig:
    :type gradient_checkpointing: bool
    :param bf16: Whether to use bf16 precision. Otherwise use fp16.
    :type bf16: bool
-   :param offload: Whether to offload model parameters to CPU. Only valid for the DeepSpeed backend.
-   :type offload: bool
    :param parallel: Configuration for parallelism.
    :type parallel: ParallelismConfig
-   :param zero_stage: Stage of ZeRO optimization. Should be one of 0, 1, 2, or 3.
-   :type zero_stage: int
    :param optimizer: Configuration for the optimizer.
    :type optimizer: Optional[OptimizerConfig]
    :param init_critic_from_actor: Whether to initialize a critic/reward model from a saved LM checkpoint.

@@ -235,11 +228,6 @@ class ModelTrainEvalConfig:
    path: str = ""
    gradient_checkpointing: bool = True
    bf16: bool = False
-   offload: bool = False
-   zero_stage: int = dataclasses.field(
-       metadata={"choices": [0, 1, 2, 3]},
-       default=2,
-   )
    optimizer: Optional[OptimizerConfig] = dataclasses.field(
        default_factory=OptimizerConfig
    )

@@ -24,7 +24,7 @@ logger = logging.getLogger(__name__)
if TYPE_CHECKING:
    from realhf.api.core.config import ModelName
    from realhf.api.core.system_api import ModelShardID
-   from realhf.base.topology import ParallelGrid, PipeModelDataParallelTopology
+   from realhf.base.topology import ParallelGrid, ProcessTopology


class GlobalMemoryBuffer:

@@ -137,7 +137,6 @@ if cluster_spec.name == "wa180":
    BASE_ENVIRONS.update(PPU_ENVIRONS)
elif cluster_spec.name == "na132":
    # Specific environment variable for h800 cluster na132
-   # FIXME: change to general cases for open source repo
    NV_ENVIRONS = {
        "NCCL_SOCKET_IFNAME": "bond0",
        "NCCL_NET_PLUGIN": "",

@@ -269,7 +268,7 @@ def set_self_group(pgroup):

def set_rank_mapping(
    model_name: "ModelName",
-   topo: "PipeModelDataParallelTopology",
+   topo: "ProcessTopology",
    msid2mwid: Optional[Dict["ModelShardID", int]] = None,
):
    global _rank_mapping

@@ -317,8 +316,8 @@ def gradient_accumulation_fusion() -> bool:
        import fused_weight_gradient_mlp_cuda
    except ImportError:
        _grad_accum_fusion_available = False
-   return (
-       _grad_accum_fusion_available and grid().topology().gradient_accumulation_fusion
+   return _grad_accum_fusion_available and getattr(
+       grid().topology(), "gradient_accumulation_fusion", False
    )
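The getattr change above exists because gradient_accumulation_fusion (and gradient_checkpointing in the next hunk) are attributes that only the training topology defines; the new inference topology carries neither, so the accessors now fall back to False. A tiny hedged sketch of that pattern, with stand-in classes rather than the real topology types:

class _TrainTopoStub:
    # stand-in for the training topology, which sets these flags in __init__
    gradient_checkpointing = True
    gradient_accumulation_fusion = False

class _InferTopoStub:
    # stand-in for the inference topology, which defines neither attribute
    pass

for topo in (_TrainTopoStub(), _InferTopoStub()):
    print(getattr(topo, "gradient_checkpointing", False))  # True, then False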
@@ -327,7 +326,7 @@ def max_prompt_len() -> int:


def gradient_checkpointing() -> bool:
-   return grid().topology().gradient_checkpointing
+   return getattr(grid().topology(), "gradient_checkpointing", False)


def has_model_name(name: str) -> bool:

@@ -21,7 +21,11 @@ import torch.utils.data

from realhf.api.core.data_api import SequenceSample
from realhf.base import constants, gpu_utils, logging, name_resolve, names, topology
-from realhf.base.topology import ParallelGrid, PipeModelDataParallelTopology
+from realhf.base.topology import (
+    DataPipeModelParallelTopology,
+    ParallelGrid,
+    PipeModelDataParallelTopology,
+)

logger = logging.getLogger("testing")

@@ -212,19 +216,28 @@ def init_global_constants(
    gradient_checkpointing=True,
    gradient_accumulation_fusion=False,
    max_prompt_len=None,
+   is_train: bool = True,
):
    model_name = model_name if model_name is not None else MODEL_NAME

    if topo is None:
-       topo = PipeModelDataParallelTopology(
-           num_dp=num_dp,
-           num_mp=num_mp,
-           num_pp=num_pp,
-           sequence_parallel=sequence_parallel,
-           gradient_checkpointing=gradient_checkpointing,
-           gradient_accumulation_fusion=gradient_accumulation_fusion,
-           max_prompt_len=max_prompt_len,
-       )
+       if is_train:
+           topo = PipeModelDataParallelTopology(
+               num_dp=num_dp,
+               num_mp=num_mp,
+               num_pp=num_pp,
+               sequence_parallel=sequence_parallel,
+               gradient_checkpointing=gradient_checkpointing,
+               gradient_accumulation_fusion=gradient_accumulation_fusion,
+               max_prompt_len=max_prompt_len,
+           )
+       else:
+           topo = DataPipeModelParallelTopology(
+               num_dp=num_dp,
+               num_mp=num_mp,
+               num_pp=num_pp,
+               sequence_parallel=sequence_parallel,
+           )
        ws = num_dp * num_mp * num_pp
    else:
        ws = topo.world_size()
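A hypothetical call of the test helper above, using only the parameters visible in this hunk (the rest of the signature lies outside the diff, so treat the argument list as illustrative): passing is_train=False makes the helper build the DP-outermost inference topology instead of the training one.

# Hypothetical invocation; argument names mirror the hunk above.
init_global_constants(
    num_dp=2,
    num_mp=1,
    num_pp=2,
    sequence_parallel=False,
    is_train=False,  # -> DataPipeModelParallelTopology instead of PipeModelDataParallelTopology
)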
@@ -65,15 +65,22 @@ def decompose_to_three_factors(n: int) -> List[Tuple[int, int, int]]:
    return factors


-class ProcessCoord(NamedTuple):
+class PipeDataModelrocessCoord(NamedTuple):
    pipe: int
    data: int
    model: int


+class DataPipeModelrocessCoord(NamedTuple):
+    data: int
+    pipe: int
+    model: int
+
+
# Explicitly define these class to allow pickling.
PROCESS_COORD_REGISTRY = {
-   "pipe#data#model": ProcessCoord,
+   "pipe#data#model": PipeDataModelrocessCoord,
+   "data#pipe#model": DataPipeModelrocessCoord,
}

@@ -320,7 +327,7 @@ def _prime_factors(N):
    return primes


-class PipeModelDataParallelTopology(ProcessTopology):
+class PipeDataModelParallelTopology(ProcessTopology):
    """A topology for hybrid pipeline, model, and data parallelism."""

    def __init__(

@@ -341,6 +348,25 @@ class PipeModelDataParallelTopology(ProcessTopology):
        self.gradient_accumulation_fusion = gradient_accumulation_fusion


+class DataPipeModelParallelTopology(ProcessTopology):
+    """A topology for hybrid data, pipeline, and tensor parallelism.
+
+    Note that DP is the most outer dimension. Used for inference only.
+    """
+
+    def __init__(
+        self,
+        num_pp: int,
+        num_mp: int,
+        num_dp: int,
+        sequence_parallel: bool,
+        max_prompt_len: Optional[int] = None,
+    ):
+        super().__init__(axes=["data", "pipe", "model"], dims=[num_dp, num_pp, num_mp])
+        self.sequence_parallel = sequence_parallel
+        self.max_prompt_len = max_prompt_len
+
+
class ParallelGrid:
    """Implements a grid object that stores the data parallel ranks
    corresponding to each of the model parallel stages.
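Because both classes inherit from ProcessTopology, the same (pipe, data, model) coordinate resolves to different global ranks depending on the axis order. The sketch below is hedged: it assumes the DeepSpeed-style get_rank(**coords) accessor that ParallelGrid also uses (see self._topo.get_rank(**transform) in the next hunk) and that the renamed training class keeps the pipe-outermost axes, and it passes every constructor keyword shown in the diff.

# Hedged usage sketch; constructor keywords follow the calls shown in this diff.
train_topo = PipeDataModelParallelTopology(
    num_pp=2,
    num_mp=1,
    num_dp=2,
    sequence_parallel=False,
    gradient_checkpointing=False,
    gradient_accumulation_fusion=False,
    max_prompt_len=None,
)
infer_topo = DataPipeModelParallelTopology(
    num_pp=2, num_mp=1, num_dp=2, sequence_parallel=False,
)

# The same shard coordinate maps to different global ranks:
print(train_topo.get_rank(pipe=1, data=0, model=0))  # 2 (pipe is outermost)
print(infer_topo.get_rank(pipe=1, data=0, model=0))  # 1 (data is outermost)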
@@ -365,7 +391,7 @@ class ParallelGrid:

    def __init__(
        self,
-       topology: PipeModelDataParallelTopology,
+       topology: ProcessTopology,
        process_group: dist.ProcessGroup,
        rank_mapping: Optional[Dict[int, int]] = None,
    ):

@@ -558,7 +584,7 @@ class ParallelGrid:
        transform = me._replace(pipe=stage_id, **kwargs)._asdict()
        return self._topo.get_rank(**transform)

-   def topology(self) -> PipeModelDataParallelTopology:
+   def topology(self) -> ProcessTopology:
        return self._topo

    # MPU functions for DeepSpeed integration

@@ -630,7 +656,7 @@ class ParallelGrid:
class FakeGrid:
    """Used for testing dynamic scheduling in none-GPU environment."""

-   def __init__(self, rank: int, topo: PipeModelDataParallelTopology):
+   def __init__(self, rank: int, topo: ProcessTopology):
        self.rank = rank
        self._topo = topo

@@ -25,10 +25,7 @@ def check_is_realhf_native_model_interface(name):
def check_valid_vllm(role: str, vllm: vLLMConfig, rpc_allocs: List[RPCAllocation]):
    rpcs = [alloc.rpc for alloc in rpc_allocs if alloc.rpc.role == role]
    if vllm.hybrid_train and not any(rpc.is_train() for rpc in rpcs):
-       logger.warning(
-           "vLLM hybrid_train is enabled, but no training RPCs are found. Set it to False."
-       )
-       vllm.hybrid_train = False
+       logger.warning("vLLM hybrid_train is enabled, but no training RPCs are found.")
    if vllm.hybrid_train and not vllm.enforce_eager:
        raise ValueError("vLLM hybrid_train requires eager mode to be enabled.")

@@ -45,18 +42,6 @@ def check_valid_optimizer(model: ModelTrainEvalConfig):
    )


-def check_valid_backend(role: str, model: ModelTrainEvalConfig):
-   if (model.offload or model.optimizer.offload) and model.backend != "deepspeed":
-       raise ValueError(
-           f"For model `{role}`, offload is only" " valid for the deepspeed backend."
-       )
-   if model.backend == "megatron" and model.zero_stage in [3]:
-       raise ValueError(
-           f"For model `{role}`, the Megatron backend"
-           " only supports zero stage 0, 1 or 2."
-       )
-
-
def check_valid_model_and_path(role: str, model: ModelTrainEvalConfig):
    if not os.path.exists(model.path):
        raise FileNotFoundError(

@@ -48,7 +48,6 @@ from realhf.api.quickstart.model import (
from realhf.base.cluster import spec as cluster_spec
from realhf.experiments.common.check import (
    check_is_realhf_native_model_interface,
-   check_valid_backend,
    check_valid_model_and_path,
    check_valid_optimizer,
    check_valid_parallel_batch_size,

@@ -572,6 +571,7 @@ class CommonExperimentConfig(Experiment):
                gradient_checkpointing=False,
                max_prompt_len=(self.max_prompt_len),
                gradient_accumulation_fusion=False,
+               is_train=False,
            )
            model_cfg = self.models[model_name.role]
            global vLLM_HYBRID_TRAIN_DECOUPLE_ALLOC_WARN

@@ -675,6 +675,7 @@ class CommonExperimentConfig(Experiment):
                ),
                gradient_accumulation_fusion=(model_cfg.backend == "megatron")
                and (model_cfg.type._class != "bailing"),
+               is_train=any(rpc.is_train() for rpc in rpcs),
            )

            if any(rpc.is_train() for rpc in rpcs):

@@ -794,6 +795,5 @@ class CommonExperimentConfig(Experiment):
        for alloc in rpc_allocs:
            check_valid_parallel_batch_size(alloc)
        for role, model in self.models.items():
-           check_valid_backend(role, model)
            check_valid_model_and_path(role, model)
            check_valid_optimizer(model)

@@ -25,7 +25,11 @@ from realhf.api.quickstart.model import (
    parallelism_eq,
)
from realhf.base import logging
-from realhf.base.topology import PipeModelDataParallelTopology
+from realhf.base.topology import (
+    DataPipeModelParallelTopology,
+    PipeDataModelParallelTopology,
+    ProcessTopology,
+)

logger = logging.getLogger("Experiment Common Utils", "benchmark")

@@ -34,16 +38,25 @@ def get_topo(
    parallel: ParallelismConfig,
    gradient_checkpointing: bool,
    gradient_accumulation_fusion: bool,
+   is_train: bool,
    max_prompt_len: Optional[int] = None,
-) -> PipeModelDataParallelTopology:
-   return PipeModelDataParallelTopology(
+) -> ProcessTopology:
+   if is_train:
+       return PipeDataModelParallelTopology(
            num_mp=parallel.model_parallel_size,
            num_pp=parallel.pipeline_parallel_size,
            num_dp=parallel.data_parallel_size,
            sequence_parallel=parallel.use_sequence_parallel,
            gradient_checkpointing=gradient_checkpointing,
            max_prompt_len=max_prompt_len,
            gradient_accumulation_fusion=gradient_accumulation_fusion,
        )
+   return DataPipeModelParallelTopology(
+       num_mp=parallel.model_parallel_size,
+       num_pp=parallel.pipeline_parallel_size,
+       num_dp=parallel.data_parallel_size,
+       sequence_parallel=parallel.use_sequence_parallel,
+       gradient_checkpointing=gradient_checkpointing,
+       max_prompt_len=max_prompt_len,
+       gradient_accumulation_fusion=gradient_accumulation_fusion,
+   )
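A hypothetical call of get_topo after this change, assuming ParallelismConfig accepts the field names used in the hunk as constructor keywords; it only illustrates how is_train selects between the two topologies.

parallel = ParallelismConfig(
    data_parallel_size=4,
    pipeline_parallel_size=2,
    model_parallel_size=1,
    use_sequence_parallel=False,
)
train_topo = get_topo(
    parallel,
    gradient_checkpointing=True,
    gradient_accumulation_fusion=False,
    is_train=True,
)  # -> PipeDataModelParallelTopology
gen_topo = get_topo(
    parallel,
    gradient_checkpointing=False,
    gradient_accumulation_fusion=False,
    is_train=False,
)  # -> DataPipeModelParallelTopology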
@@ -58,49 +71,12 @@ def get_world_size(parallel: ParallelismConfig) -> int:
def make_train_backend_config(
    model_cfg: ModelTrainEvalConfig, parallel_cfg: ParallelismConfig
):
-   if model_cfg.backend == "deepspeed":
-       return ModelBackendAbstraction(
-           "deepspeed",
-           args=dict(
-               optimizer_name="adam",
-               optimizer_config=dict(
-                   lr=model_cfg.optimizer.lr,
-                   weight_decay=model_cfg.optimizer.weight_decay,
-                   eps=model_cfg.optimizer.eps,
-                   betas=(
-                       model_cfg.optimizer.beta1,
-                       model_cfg.optimizer.beta2,
-                   ),
-               ),
-               lr_scheduler_type=model_cfg.optimizer.lr_scheduler_type,
-               warmup_steps_proportion=model_cfg.optimizer.warmup_steps_proportion,
-               min_lr_ratio=model_cfg.optimizer.min_lr_ratio,
-               zero_stage=(
-                   model_cfg.zero_stage
-                   if parallel_cfg.pipeline_parallel_size == 1
-                   else min(model_cfg.zero_stage, 1)
-               ),
-               offload_optimizer_state=model_cfg.optimizer.offload,
-               offload_param=model_cfg.offload,
-               bf16=model_cfg.bf16,
-           ),
-       )
-   elif model_cfg.backend == "megatron":
-       if model_cfg.optimizer.offload or model_cfg.offload:
-           raise ValueError("Offload is not supported in Megatron backend.")
-       if model_cfg.zero_stage == 3:
-           raise ValueError("Zero stage 3 is not supported in Megatron backend.")
-       if model_cfg.zero_stage == 2:
-           logger.warning(
-               "Megatron does not support ZeRO stage 2. Degenerates to stage 1."
-           )
-           model_cfg.zero_stage = 1
+   if model_cfg.backend == "megatron":
        megatron_args: Dict[str, Any] = OmegaConf.to_container(model_cfg.megatron)
        return ModelBackendAbstraction(
            "megatron",
            args=dict(
                bf16=model_cfg.bf16,
                zero_stage=model_cfg.zero_stage,
                optimizer=model_cfg.optimizer,
                **megatron_args,
            ),

@@ -225,10 +201,12 @@ def resolve_rpc_hooks(


class AllocationType(enum.Enum):
-   DECOUPLED = 1
+   DECOUPLED_vLLM = 1
    GLOBAL_HYBRID = 2
    MANUAL = 3
-   SEARCH = 4
+   HEURISTIC = 4
+   SEARCH = 5
+   DECOUPLED_SGLANG = 6


@dataclasses.dataclass

@@ -237,7 +215,16 @@ class AllocationMode:
    parallel_strat: Dict[str, Dict[str, int]]

    def is_decoupled(self):
-       return self.type_ == AllocationType.DECOUPLED
+       return self.type_ in [
+           AllocationType.DECOUPLED_vLLM,
+           AllocationType.DECOUPLED_SGLANG,
+       ]
+
+   def is_decoupled_vllm(self):
+       return self.type_ == AllocationType.DECOUPLED_vLLM
+
+   def is_decoupled_sglang(self):
+       return self.type_ == AllocationType.DECOUPLED_SGLANG

    def is_global_hybrid(self):
        return self.type_ == AllocationType.GLOBAL_HYBRID

@@ -246,13 +233,19 @@ class AllocationMode:
    def from_str(cls, allocation_mode: str):
        if allocation_mode == "manual":
            return cls(AllocationType.MANUAL, None)
+       if allocation_mode == "heuristic":
+           return cls(AllocationType.HEURISTIC, None)
        if allocation_mode == "search":
            return cls(AllocationType.SEARCH, None)

        alloc_3d = AllocationMode.extract_3d_alloc(allocation_mode)
        alloc_hybrid = AllocationMode.extract_key_value_alloc(allocation_mode)
        alloc_decoupled = AllocationMode.extract_decoupled_alloc(allocation_mode)
        if alloc_decoupled:
-           return cls(AllocationType.DECOUPLED, alloc_decoupled)
+           if "vllm" in allocation_mode:
+               return cls(AllocationType.DECOUPLED_vLLM, alloc_decoupled)
+           elif "sglang" in allocation_mode:
+               return cls(AllocationType.DECOUPLED_SGLANG, alloc_decoupled)
        if alloc_3d:
            return cls(AllocationType.GLOBAL_HYBRID, alloc_3d)
        if alloc_hybrid:
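A small sketch of the dispatch above, using only the literal mode strings that appear in this hunk; the string format for decoupled "vllm"/"sglang" allocations is parsed by extract_decoupled_alloc and is not shown here, so it is left out.

for mode in ("manual", "heuristic", "search"):
    alloc = AllocationMode.from_str(mode)
    # MANUAL, HEURISTIC, SEARCH respectively; none of these is a decoupled allocation.
    print(mode, alloc.type_, alloc.is_decoupled())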
@@ -26,7 +26,7 @@ class NCCLProcessGroupInfo:

def filter_match_mwids(
    model_name: ModelName,
-   topo: topology.PipeModelDataParallelTopology,
+   topo: topology.ProcessTopology,
    msid2mwid: Dict[ModelShardID, int],
    **conditions,
) -> List[int]:

@@ -49,7 +49,7 @@ def setup_global_comm(
    expr_name: str,
    trial_name: str,
    worker_index: int,
-   model_topos: Optional[Dict[str, topology.PipeModelDataParallelTopology]] = None,
+   model_topos: Optional[Dict[str, topology.ProcessTopology]] = None,
    msid2mwid: Optional[Dict[ModelShardID, int]] = None,
    backend: str = "nccl",
) -> NCCLProcessGroupInfo:

@@ -152,8 +152,8 @@ def _assign_src_to_dsts(


def _create_param_realloc_groups(
-   from_topo: topology.PipeModelDataParallelTopology,
-   to_topo: topology.PipeModelDataParallelTopology,
+   from_topo: topology.ProcessTopology,
+   to_topo: topology.ProcessTopology,
    src: ModelName,
    dst: ModelName,
    msid2mwid: Dict[ModelShardID, int],

@@ -262,7 +262,7 @@ def _create_param_realloc_groups(


def setup_param_realloc(
-   model_topos: Optional[Dict[str, topology.PipeModelDataParallelTopology]] = None,
+   model_topos: Optional[Dict[str, topology.ProcessTopology]] = None,
    msid2mwid: Optional[Dict[ModelShardID, int]] = None,
    param_realloc_pairs: Optional[List[Tuple[ModelName, ModelName]]] = None,
) -> ParamReallocInfo:

@@ -341,8 +341,8 @@ class ReparallelizeReceiverStep:
def _derive_reparallelize_comm_plan(
    from_model_name: ModelName,
    to_model_name: ModelName,
-   from_topo: topology.PipeModelDataParallelTopology,
-   to_topo: topology.PipeModelDataParallelTopology,
+   from_topo: topology.ProcessTopology,
+   to_topo: topology.ProcessTopology,
    from_model_config: model_api.ReaLModelConfig,
    to_model_config: model_api.ReaLModelConfig,
    pg_info: ParamReallocInfo,

@@ -562,8 +562,8 @@ class ReaLModel(nn.Module):
        self,
        from_model_name: ModelName,
        to_model_name: ModelName,
-       from_topo: topology.PipeModelDataParallelTopology,
-       to_topo: topology.PipeModelDataParallelTopology,
+       from_topo: topology.ProcessTopology,
+       to_topo: topology.ProcessTopology,
        to_model_config: model_api.ReaLModelConfig,
        pg_info: NCCLProcessGroupInfo,
        from_model_config: None | model_api.ReaLModelConfig = None,

@@ -638,8 +638,8 @@ class ReaLModel(nn.Module):
        self,
        from_model_name: ModelName,
        to_model_name: ModelName,
-       from_topo: topology.PipeModelDataParallelTopology,
-       to_topo: topology.PipeModelDataParallelTopology,
+       from_topo: topology.ProcessTopology,
+       to_topo: topology.ProcessTopology,
        to_model_config: model_api.ReaLModelConfig,
        pg_info: NCCLProcessGroupInfo,
    ) -> Tuple[nn.ModuleList, torch.Tensor, torch.Tensor]:

@@ -39,8 +39,8 @@ def compute_cost(
    world_size: int,
    from_model_name: ModelName,
    to_model_name: ModelName,
-   from_topo: topology.PipeModelDataParallelTopology,
-   to_topo: topology.PipeModelDataParallelTopology,
+   from_topo: topology.ProcessTopology,
+   to_topo: topology.ProcessTopology,
    model_config: ReaLModelConfig,
    bw: float,  # Gbps
    set_interval_cost: float,

@@ -161,10 +161,10 @@ def dump_table(
):
    world_size = max(a, b)

-   from_topo = topology.PipeModelDataParallelTopology(
+   from_topo = topology.PipeDataModelParallelTopology(
        *from_pp_mp_dp, False, False
    )
-   to_topo = topology.PipeModelDataParallelTopology(*to_pp_mp_dp, False, False)
+   to_topo = topology.PipeDataModelParallelTopology(*to_pp_mp_dp, False, False)
    assert world_size >= from_topo.world_size()
    assert world_size >= to_topo.world_size()