mirror of https://github.com/inclusionAI/AReaL
PullRequest: 7 Support key-value allocation when using decoupled vLLM generation.
Merge branch fw/vllm-key-value-alloc of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/7?tab=commit
Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* .
* remove print
This commit is contained in:
parent e7c4a49adc
commit ceee49454a
@@ -15,6 +15,7 @@ from typing import *
import numpy as np
import transformers
from omegaconf import MISSING, OmegaConf
from transformers.utils import is_accelerate_available

import realhf.base.logging as logging
from realhf.api.core.config import (
@@ -58,9 +59,7 @@ from realhf.experiments.common.check import (
    check_valid_vllm,
)
from realhf.experiments.common.utils import (
-    extract_decoupled_vllm_train_allocation,
-    extract_key_value_allocation,
-    extract_symmetric_allocation,
+    AllocationMode,
    get_topo,
    make_inf_backend_config,
    make_train_backend_config,
@@ -111,10 +110,6 @@ class CommonExperimentConfig(Experiment):

    - ``heuristic``\: Allocate resources and configure parallel strategies using heuristic strategies obtained from a search.

    - ``pipe_data``\: Identical parallelization (like DSChat) with pipe+data parallelism. For a world size under 8, only data parallelism will be used.

    - ``pipe_model``\: Identical parallelization (like DSChat) with pipe+model parallelism. For a world size under 8, only tensor-model parallelism will be used.

    - A regex pattern like ``d${DP}p${PP}m${TP}``\: Identical parallelization for all MFCs with ${DP}-way data parallelism, ${PP}-way pipeline parallelism, and ${TP}-way model parallelism.

    - A regex pattern like ``vllm.{IdentPara}+{IdentPara}``\: Decoupled generation (vLLM) and training allocations with corresponding identical parallelization strategies. Note that the pipeline parallel degree of vLLM can only be 1.
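To make the patterns above concrete, a few example strings (a hedged sketch; the RPC name `actor_train` is hypothetical, only the grammar comes from this diff):

    allocation_mode = "d4p2m1"              # one shared 3D strategy: DP=4, PP=2, TP=1
    allocation_mode = "vllm.d2p1m4+d4p1m1"  # decoupled vLLM mesh (2x1x4) plus training mesh (4x1x1)
    allocation_mode = "vllm.d2p1m4+actor_train:d4p1m1,*:d2p2m1"  # this PR: per-RPC key-value strategies with a `*` default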
@@ -204,7 +199,7 @@ class CommonExperimentConfig(Experiment):
    recover_retries: int = 1
    recover_after: int = 10
    ignore_worker_error: bool = False
-    allocation_mode: str = "pipe_model"
+    allocation_mode: str = ""
    allocation_use_cache: bool = False
    n_nodes: int = 1
    n_gpus_per_node: int = cluster_spec.n_gpus_per_node
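Note that with the default changed from "pipe_model" to an empty string, `AllocationMode.from_str("")` matches none of the parsers introduced below and raises `NotImplementedError`, so the allocation mode must now be set explicitly.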
@@ -358,6 +353,8 @@ class CommonExperimentConfig(Experiment):

        self.__check_legal_allocation_options()

+        self._allocation_mode = AllocationMode.from_str(self.allocation_mode)
+
        rpcs = self.rpcs
        if self.allocation_mode == "search":
            # assert self.mode == "slurm"
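For reference, `from_str` (defined later in this diff) dispatches on the string's shape. A minimal sketch of the expected results:

    mode = AllocationMode.from_str("vllm.d2p1m4+d4p1m1")
    assert mode.is_decoupled()
    # {"*": {"d": 4, "p": 1, "m": 1}, "gen": {"d": 2, "p": 1, "m": 4}}
    print(mode.parallel_strat)

    mode = AllocationMode.from_str("d4p2m1")
    assert mode.is_global_hybrid()  # plain 3D strings also map to GLOBAL_HYBRID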
@@ -372,90 +369,72 @@ class CommonExperimentConfig(Experiment):
                    break
            else:
                raise ValueError(f"RPC {rpc_alloc.rpc} not found in rpcs.")
        elif (
            self.allocation_mode == "pipe_data"
            or self.allocation_mode == "pipe_model"
            or extract_symmetric_allocation(self.allocation_mode)
        ):
            if self.allocation_mode == "pipe_data":
                dp, pp, mp = self.n_gpus_per_node, self.n_nodes, 1
            elif self.allocation_mode == "pipe_model":
                dp, pp, mp = 1, self.n_nodes, self.n_gpus_per_node
            else:
                para = extract_symmetric_allocation(self.allocation_mode)
                dp, pp, mp = para["d"], para["p"], para["m"]
            if dp * pp * mp != self.n_nodes * self.n_gpus_per_node:
                raise ValueError(
                    "The product of the 3D parallel degrees "
                    "does not equal the number of GPUs. "
                    f"dp={dp}, pp={pp}, mp={mp}, "
                    f"n_nodes={self.n_nodes}, n_gpus_per_node={self.n_gpus_per_node}"
                )
            rpc_allocs: List[RPCAllocation] = [
                RPCAllocation(
                    rpc=rpc,
                    device_mesh=self.global_device_mesh,
                    parallel=ParallelismConfig(
                        data_parallel_size=dp,
                        pipeline_parallel_size=pp,
                        model_parallel_size=mp,
                        use_sequence_parallel=(
                            rpc.interface_type == ModelInterfaceType.TRAIN_STEP
                            and mp > 1
                        ),
                    ),
                )
                for rpc in rpcs.values()
            ]
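A quick sanity check of the product rule above, assuming 2 nodes with 8 GPUs each: `d4p2m2` gives dp * pp * mp = 4 * 2 * 2 = 16 = 2 * 8 and passes, while `d4p2m1` (product 8) would raise the ValueError.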
-        elif extract_decoupled_vllm_train_allocation(self.allocation_mode):
-            para = extract_decoupled_vllm_train_allocation(self.allocation_mode)
-            dp, pp, mp = para["d"], para["p"], para["m"]
-            vdp, vpp, vmp = para["vllm.d"], para["vllm.p"], para["vllm.m"]
-            vllm_world_size = vdp * vpp * vmp
-            if dp * pp * mp + vdp * vpp * vmp != self.n_nodes * self.n_gpus_per_node:
-                raise ValueError(
-                    "The multiplication of 3D parallel degrees "
-                    "does not equal to the number of gpus. "
-                    "Note that the device mesh of vLLM should be disjoint from the device mesh of other MFCs, "
-                    "so their summation should be equal to the total number of gpus. "
-                    f"dp={dp}, pp={pp}, mp={mp}, vllm.dp={vdp}, vllm.pp={vpp}, vllm.mp={vmp}, "
-                    f"n_nodes={self.n_nodes}, n_gpus_per_node={self.n_gpus_per_node}"
-                )
+        elif self._allocation_mode.is_decoupled():
+            paras = self._allocation_mode.parallel_strat
+
+            gdp, gpp, gmp = paras["gen"]["d"], paras["gen"]["p"], paras["gen"]["m"]
+            gen_world_size = gdp * gpp * gmp
            assert (
-                vllm_world_size < self.n_gpus_per_node
-                or vllm_world_size % self.n_gpus_per_node == 0
+                gen_world_size < self.n_gpus_per_node
+                or gen_world_size % self.n_gpus_per_node == 0
            )
-            vllm_device_mesh, train_device_mesh = self.global_device_mesh.split(
-                vllm_world_size
+            gen_device_mesh, train_device_mesh = self.global_device_mesh.split(
+                gen_world_size
            )

-            self.vllm_device_mesh = vllm_device_mesh
+            self.gen_device_mesh = gen_device_mesh
            self.train_device_mesh = train_device_mesh

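To make the disjoint-mesh bookkeeping concrete (assuming 1 node with 8 GPUs): `vllm.d2p1m2+d2p2m1` yields gen_world_size = 2 * 1 * 2 = 4, leaving 2 * 2 * 1 = 4 GPUs for training, which sums to 8. The assert additionally requires the generation mesh to either fit within a single node (gen_world_size < 8) or occupy whole nodes (a multiple of 8), so that `split` can carve it off cleanly.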
            rpc_allocs = []
            flag = False
            for rpc in rpcs.values():
-                if rpc.interface_type == ModelInterfaceType.GENERATE:
-                    if vpp != 1:
+                if rpc.is_generate():
+                    if gpp != 1:
                        raise NotImplementedError(
-                            "vllm pipeline parallel is not supported yet."
+                            "vllm/sglang pipeline parallel is not supported yet."
                        )
                    if flag:
                        raise NotImplementedError(
-                            "vllm does not support two generation RPCs for now."
+                            "vllm/sglang does not support two generation RPCs for now."
                        )
                    alloc = RPCAllocation(
                        rpc=rpc,
-                        device_mesh=vllm_device_mesh,
+                        device_mesh=gen_device_mesh,
                        parallel=ParallelismConfig(
-                            data_parallel_size=vdp,
-                            pipeline_parallel_size=vpp,
-                            model_parallel_size=vmp,
+                            data_parallel_size=gdp,
+                            pipeline_parallel_size=gpp,
+                            model_parallel_size=gmp,
                            use_sequence_parallel=False,
                        ),
                    )
                    flag = True
                else:
+                    rpc_name = rpc.name
+                    if rpc_name in paras:
+                        dp, pp, mp = (
+                            paras[rpc_name]["d"],
+                            paras[rpc_name]["p"],
+                            paras[rpc_name]["m"],
+                        )
+                    else:
+                        if "*" not in paras:
+                            raise ValueError(
+                                f"RPC {rpc_name} parallel strategy not given, "
+                                "expect a `*` to specify the default parallel strategy."
+                            )
+                        dp, pp, mp = paras["*"]["d"], paras["*"]["p"], paras["*"]["m"]
+                    if (
+                        dp * pp * mp + gdp * gpp * gmp
+                        != self.n_nodes * self.n_gpus_per_node
+                    ):
+                        raise ValueError(
+                            "The product of the 3D parallel degrees "
+                            "does not equal the number of GPUs. "
+                            "Note that the device mesh of vLLM should be disjoint from the device mesh of other MFCs, "
+                            "so their sum should equal the total number of GPUs. "
+                            f"dp={dp}, pp={pp}, mp={mp}, vllm.dp={gdp}, vllm.pp={gpp}, vllm.mp={gmp}, "
+                            f"n_nodes={self.n_nodes}, n_gpus_per_node={self.n_gpus_per_node}"
+                        )
                    alloc = RPCAllocation(
                        rpc=rpc,
                        device_mesh=train_device_mesh,
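For the key-value path added here, an allocation such as `vllm.d2p1m2+actor_train:d2p2m1,*:d4p1m1` (RPC names hypothetical) resolves `actor_train` to its own strategy and every other training RPC to the `*` default; each resolved strategy is then checked so that its GPU count plus the generation mesh equals the total, e.g. 2*2*1 + 2*1*2 = 8 on a single 8-GPU node.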
@@ -472,10 +451,10 @@ class CommonExperimentConfig(Experiment):
                    rpc_allocs.append(alloc)
            if not flag:
                raise ValueError(
-                    "No generation RPC found. Please use the allocation mode without vllm."
+                    "No generation RPC found. Please use the hybrid train allocation mode."
                )
-        elif extract_key_value_allocation(self.allocation_mode):
-            paras = extract_key_value_allocation(self.allocation_mode)
+        elif self._allocation_mode.is_global_hybrid():
+            paras = self._allocation_mode.parallel_strat
            rpc_allocs = []
            for rpc_name, rpc in self.rpcs.items():
                if rpc_name in paras:
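In the purely global-hybrid case there is no generation mesh and `parallel_strat` is just the per-RPC mapping; e.g. `actor_train:d4p2m1,actor_gen:d8p1m1` (hypothetical names) parses to {"actor_train": {"d": 4, "p": 2, "m": 1}, "actor_gen": {"d": 8, "p": 1, "m": 1}}.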
@@ -558,8 +537,8 @@ class CommonExperimentConfig(Experiment):

        # vLLM enabled model worker, shortcut case
        if (
-            extract_decoupled_vllm_train_allocation(self.allocation_mode)
-            and self.vllm_device_mesh.mapping[i, j]
+            self._allocation_mode.is_decoupled()
+            and self.gen_device_mesh.mapping[i, j]
        ):
            gen_rpc_alloc = next(
                alloc
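The shortcut reads as: if the allocation is decoupled and this model worker's (i, j) coordinate falls inside the generation mesh, configure it directly from the generation RPC's allocation rather than from the training allocations.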
@@ -3,6 +3,8 @@
# Licensed under the Apache License, Version 2.0 (the "License").

import collections
+import dataclasses
+import enum
import itertools
import re
from typing import *
@@ -224,58 +226,94 @@ def resolve_rpc_hooks(
            logger.info(f"Add offload hook for rpc {rpc.name} for role {rpc.role}")


-def extract_symmetric_allocation(allocation_mode: str) -> Dict | None:
-    for x, y, z in itertools.permutations(["d", "m", "p"]):
-        pattern = rf"{x}(\d+){y}(\d+){z}(\d+)"
-        m = re.match(pattern, allocation_mode)
-        if not m:
-            continue
-        a, b, c = map(int, m.groups())
-        return {
-            x: a,
-            y: b,
-            z: c,
-        }
-
-
-def extract_decoupled_vllm_train_allocation(allocation_mode: str) -> Dict | None:
-    pattern = re.compile(r"(?:vllm\.(.+?)\+(.+))|(?:(.+?)\+vllm\.(.+))")
-    m = pattern.match(allocation_mode)
-    if not m:
-        return
-    if m.group(1):
-        vllm_alloc = m.group(1)
-        other_alloc = m.group(2)
-    else:
-        vllm_alloc = m.group(4)
-        other_alloc = m.group(3)
-    vllm_alloc = extract_symmetric_allocation(vllm_alloc)
-    other_alloc = extract_symmetric_allocation(other_alloc)
-    if not vllm_alloc:
-        return
-    if not other_alloc:
-        return
-    other_alloc.update({"vllm." + k: v for k, v in vllm_alloc.items()})
-    return other_alloc
-
-
-def parse_key_value_pairs(s: str):
-    pattern = re.compile(r"([^:,]+):([^:,]+)")
-    matches = pattern.findall(s)
-    if not matches:
-        return None
-    return {key: value for key, value in matches}
-
-
-def extract_key_value_allocation(
-    allocation_mode: str,
-) -> Dict[str, Dict[str, int]] | None:
-    allocs = parse_key_value_pairs(allocation_mode)
-    if not allocs:
-        return
-    for k, v in allocs.items():
-        v = extract_symmetric_allocation(v)
-        if not v:
-            return
-        allocs[k] = v
-    return allocs
+class AllocationType(enum.Enum):
+    DECOUPLED = 1
+    GLOBAL_HYBRID = 2
+
+
+@dataclasses.dataclass
+class AllocationMode:
+    type_: AllocationType
+    parallel_strat: Dict[str, Dict[str, int]]
+
+    def is_decoupled(self):
+        return self.type_ == AllocationType.DECOUPLED
+
+    def is_global_hybrid(self):
+        return self.type_ == AllocationType.GLOBAL_HYBRID
+
+    @classmethod
+    def from_str(cls, allocation_mode: str):
+        alloc_3d = AllocationMode.extract_3d_alloc(allocation_mode)
+        alloc_hybrid = AllocationMode.extract_key_value_alloc(allocation_mode)
+        alloc_decoupled = AllocationMode.extract_decoupled_alloc(allocation_mode)
+        if alloc_decoupled:
+            return cls(AllocationType.DECOUPLED, alloc_decoupled)
+        if alloc_3d:
+            return cls(AllocationType.GLOBAL_HYBRID, alloc_3d)
+        if alloc_hybrid:
+            return cls(AllocationType.GLOBAL_HYBRID, alloc_hybrid)
+        raise NotImplementedError(f"Failed to parse allocation: {allocation_mode}")
+
+    @staticmethod
+    def extract_3d_alloc(allocation_mode: str) -> Dict | None:
+        for x, y, z in itertools.permutations(["d", "m", "p"]):
+            pattern = rf"{x}(\d+){y}(\d+){z}(\d+)"
+            m = re.match(pattern, allocation_mode)
+            if not m:
+                continue
+            a, b, c = map(int, m.groups())
+            # to be consistent with the key-value pattern
+            return {
+                "*": {
+                    x: a,
+                    y: b,
+                    z: c,
+                }
+            }
+
+    @staticmethod
+    def extract_decoupled_alloc(allocation_mode: str) -> Dict | None:
+        pattern = re.compile(
+            r"(?:(?:vllm|sglang)\.(.+?)\+(.+))|(?:(.+?)\+(?:vllm|sglang)\.(.+))"
+        )
+        m = pattern.match(allocation_mode)
+        if not m:
+            return
+        if m.group(1):
+            gen_alloc = m.group(1)
+            other_alloc = m.group(2)
+        else:
+            gen_alloc = m.group(4)
+            other_alloc = m.group(3)
+        gen_alloc = AllocationMode.extract_3d_alloc(gen_alloc)
+        if not gen_alloc:
+            return
+        other_alloc = AllocationMode.extract_3d_alloc(
+            other_alloc
+        ) or AllocationMode.extract_key_value_alloc(other_alloc)
+        if not other_alloc:
+            return
+        other_alloc.update({"gen": gen_alloc["*"]})
+        return other_alloc
+
+    @staticmethod
+    def extract_key_value_alloc(
+        allocation_mode: str,
+    ) -> Dict[str, Dict[str, int]] | None:
+        def parse_key_value_pairs(s: str):
+            pattern = re.compile(r"([^:,]+):([^:,]+)")
+            matches = pattern.findall(s)
+            if not matches:
+                return None
+            return {key: value for key, value in matches}
+
+        allocs = parse_key_value_pairs(allocation_mode)
+        if not allocs:
+            return
+        for k, v in allocs.items():
+            v = AllocationMode.extract_3d_alloc(v)
+            if not v:
+                return
+            allocs[k] = v["*"]
+        return allocs
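To tie the three parsers together, a minimal runnable sketch (the import path follows the diff above; the RPC name `actor_train` is made up):

    from realhf.experiments.common.utils import AllocationMode

    # Plain 3D parallelization: a single strategy under the "*" wildcard.
    m = AllocationMode.from_str("d4p2m1")
    assert m.is_global_hybrid()
    assert m.parallel_strat == {"*": {"d": 4, "p": 2, "m": 1}}

    # Decoupled generation plus per-RPC key-value training strategies.
    m = AllocationMode.from_str("vllm.d2p1m2+actor_train:d2p2m1,*:d4p1m1")
    assert m.is_decoupled()
    assert m.parallel_strat["gen"] == {"d": 2, "p": 1, "m": 2}
    assert m.parallel_strat["actor_train"] == {"d": 2, "p": 2, "m": 1}
    assert m.parallel_strat["*"] == {"d": 4, "p": 1, "m": 1}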