PullRequest: 7 Support key-value allocation when using decoupled vLLM generation.

Merge branch fw/vllm-key-value-alloc of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/7?tab=commit

Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com>
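For context, a hedged sketch of the allocation strings this change is meant to accept; the parallel degrees and the RPC names `actor_train`/`critic_train` are hypothetical placeholders, not taken from this diff:

# A minimal sketch, assuming the AllocationMode class introduced in this PR.
from realhf.experiments.common.utils import AllocationMode

# Decoupled vLLM generation plus one 3D strategy shared by all training MFCs.
AllocationMode.from_str("vllm.d4p1m2+d2p2m2")

# New in this PR: decoupled vLLM generation plus per-RPC (key-value) training strategies.
AllocationMode.from_str("vllm.d4p1m2+actor_train:d2p2m2,critic_train:d4p1m2")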


* .
* remove print
博惟 2025-03-02 13:13:47 +08:00
parent e7c4a49adc
commit ceee49454a
2 changed files with 145 additions and 128 deletions

View File

@@ -15,6 +15,7 @@ from typing import *
import numpy as np
import transformers
from omegaconf import MISSING, OmegaConf
from transformers.utils import is_accelerate_available
import realhf.base.logging as logging
from realhf.api.core.config import (
@@ -58,9 +59,7 @@ from realhf.experiments.common.check import (
check_valid_vllm,
)
from realhf.experiments.common.utils import (
extract_decoupled_vllm_train_allocation,
extract_key_value_allocation,
extract_symmetric_allocation,
AllocationMode,
get_topo,
make_inf_backend_config,
make_train_backend_config,
@@ -111,10 +110,6 @@ class CommonExperimentConfig(Experiment):
- ``heuristic``\: Allocate resources and configure parallel strategies using heuristic strategies obtained from a search.
- ``pipe_data``\: Identical parallelization (like DSChat) with pipe+data parallelism. For a world size under 8, only data parallelism will be used.
- ``pipe_model``\: Identical parallelization (like DSChat) with pipe+model parallelism. For a world size under 8, only tensor-model parallelism will be used.
- A regex pattern like ``d${DP}p${PP}m${TP}``\: Identical parallelization for all MFCs with ${DP}-way data parallelism, ${PP}-way pipeline parallelism, and ${TP}-way model parallelism.
- A regex pattern like ``vllm.{IdentPara}+{IdentPara}``\: Decoupled generation (vLLM) and training allocations with corresponding identical parallelization strategies. Note that the pipeline parallel degree of vLLM can only be 1.
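For illustration, hypothetical example values matching the patterns above, assuming 2 nodes with 8 GPUs each:

# Identical 3D parallelization for every MFC: 4-way data, 2-way pipeline, 2-way model parallel (4*2*2 = 16 GPUs).
allocation_mode = "d4p2m2"
# Decoupled vLLM generation (4*1*2 = 8 GPUs) plus training (2*2*2 = 8 GPUs); the vLLM pipeline degree must stay 1.
allocation_mode = "vllm.d4p1m2+d2p2m2"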
@@ -204,7 +199,7 @@ class CommonExperimentConfig(Experiment):
recover_retries: int = 1
recover_after: int = 10
ignore_worker_error: bool = False
allocation_mode: str = "pipe_model"
allocation_mode: str = ""
allocation_use_cache: bool = False
n_nodes: int = 1
n_gpus_per_node: int = cluster_spec.n_gpus_per_node
@@ -358,6 +353,8 @@ class CommonExperimentConfig(Experiment):
self.__check_legal_allocation_options()
self._allocation_mode = AllocationMode.from_str(self.allocation_mode)
rpcs = self.rpcs
if self.allocation_mode == "search":
# assert self.mode == "slurm"
@@ -372,90 +369,72 @@ class CommonExperimentConfig(Experiment):
break
else:
raise ValueError(f"RPC {rpc_alloc.rpc} not found in rpcs.")
elif (
self.allocation_mode == "pipe_data"
or self.allocation_mode == "pipe_model"
or extract_symmetric_allocation(self.allocation_mode)
):
if self.allocation_mode == "pipe_data":
dp, pp, mp = self.n_gpus_per_node, self.n_nodes, 1
elif self.allocation_mode == "pipe_model":
dp, pp, mp = 1, self.n_nodes, self.n_gpus_per_node
else:
para = extract_symmetric_allocation(self.allocation_mode)
dp, pp, mp = para["d"], para["p"], para["m"]
if dp * pp * mp != self.n_nodes * self.n_gpus_per_node:
raise ValueError(
"The multiplication of 3D parallel degrees "
"does not equal to the number of gpus. "
f"dp={dp}, pp={pp}, mp={mp}, "
f"n_nodes={self.n_nodes}, n_gpus_per_node={self.n_gpus_per_node}"
)
rpc_allocs: List[RPCAllocation] = [
RPCAllocation(
rpc=rpc,
device_mesh=self.global_device_mesh,
parallel=ParallelismConfig(
data_parallel_size=dp,
pipeline_parallel_size=pp,
model_parallel_size=mp,
use_sequence_parallel=(
rpc.interface_type == ModelInterfaceType.TRAIN_STEP
and mp > 1
),
),
)
for rpc in rpcs.values()
]
elif extract_decoupled_vllm_train_allocation(self.allocation_mode):
para = extract_decoupled_vllm_train_allocation(self.allocation_mode)
dp, pp, mp = para["d"], para["p"], para["m"]
vdp, vpp, vmp = para["vllm.d"], para["vllm.p"], para["vllm.m"]
vllm_world_size = vdp * vpp * vmp
if dp * pp * mp + vdp * vpp * vmp != self.n_nodes * self.n_gpus_per_node:
raise ValueError(
"The multiplication of 3D parallel degrees "
"does not equal to the number of gpus. "
"Note that the device mesh of vLLM should be disjoint from the device mesh of other MFCs, "
"so their summation should be equal to the total number of gpus. "
f"dp={dp}, pp={pp}, mp={mp}, vllm.dp={vdp}, vllm.pp={vpp}, vllm.mp={vmp}, "
f"n_nodes={self.n_nodes}, n_gpus_per_node={self.n_gpus_per_node}"
)
elif self._allocation_mode.is_decoupled():
paras = self._allocation_mode.parallel_strat
gdp, gpp, gmp = paras["gen"]["d"], paras["gen"]["p"], paras["gen"]["m"]
gen_world_size = gdp * gpp * gmp
assert (
vllm_world_size < self.n_gpus_per_node
or vllm_world_size % self.n_gpus_per_node == 0
gen_world_size < self.n_gpus_per_node
or gen_world_size % self.n_gpus_per_node == 0
)
vllm_device_mesh, train_device_mesh = self.global_device_mesh.split(
vllm_world_size
gen_device_mesh, train_device_mesh = self.global_device_mesh.split(
gen_world_size
)
self.vllm_device_mesh = vllm_device_mesh
self.gen_device_mesh = gen_device_mesh
self.train_device_mesh = train_device_mesh
rpc_allocs = []
flag = False
for rpc in rpcs.values():
if rpc.interface_type == ModelInterfaceType.GENERATE:
if vpp != 1:
if rpc.is_generate():
if gpp != 1:
raise NotImplementedError(
"vllm pipeline parallel is not supported yet."
"vllm/sglang pipeline parallel is not supported yet."
)
if flag:
raise NotImplementedError(
"vllm does not support two generation RPCs for now."
"vllm/sglang does not support two generation RPCs for now."
)
alloc = RPCAllocation(
rpc=rpc,
device_mesh=vllm_device_mesh,
device_mesh=gen_device_mesh,
parallel=ParallelismConfig(
data_parallel_size=vdp,
pipeline_parallel_size=vpp,
model_parallel_size=vmp,
data_parallel_size=gdp,
pipeline_parallel_size=gpp,
model_parallel_size=gmp,
use_sequence_parallel=False,
),
)
flag = True
else:
rpc_name = rpc.name
if rpc_name in paras:
dp, pp, mp = (
paras[rpc_name]["d"],
paras[rpc_name]["p"],
paras[rpc_name]["m"],
)
else:
if "*" not in paras:
raise ValueError(
f"RPC {rpc_name} parallel strategy not given; "
"expected a `*` entry to specify the default parallel strategy."
)
dp, pp, mp = paras["*"]["d"], paras["*"]["p"], paras["*"]["m"]
if (
dp * pp * mp + gdp * gpp * gmp
!= self.n_nodes * self.n_gpus_per_node
):
raise ValueError(
"The multiplication of 3D parallel degrees "
"does not equal to the number of gpus. "
"Note that the device mesh of vLLM should be disjoint from the device mesh of other MFCs, "
"so their summation should be equal to the total number of gpus. "
f"dp={dp}, pp={pp}, mp={mp}, vllm.dp={gdp}, vllm.pp={gpp}, vllm.mp={gmp}, "
f"n_nodes={self.n_nodes}, n_gpus_per_node={self.n_gpus_per_node}"
)
alloc = RPCAllocation(
rpc=rpc,
device_mesh=train_device_mesh,
@@ -472,10 +451,10 @@ class CommonExperimentConfig(Experiment):
rpc_allocs.append(alloc)
if not flag:
raise ValueError(
"No generation RPC found. Please use the allocation mode without vllm."
"No generation RPC found. Please use the hybrid train allocation mode."
)
elif extract_key_value_allocation(self.allocation_mode):
paras = extract_key_value_allocation(self.allocation_mode)
elif self._allocation_mode.is_global_hybrid():
paras = self._allocation_mode.parallel_strat
rpc_allocs = []
for rpc_name, rpc in self.rpcs.items():
if rpc_name in paras:
@@ -558,8 +537,8 @@ class CommonExperimentConfig(Experiment):
# vLLM enabled model worker, shortcut case
if (
extract_decoupled_vllm_train_allocation(self.allocation_mode)
and self.vllm_device_mesh.mapping[i, j]
self._allocation_mode.is_decoupled()
and self.gen_device_mesh.mapping[i, j]
):
gen_rpc_alloc = next(
alloc

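To make the disjoint-mesh check above concrete, a small arithmetic sketch; the 2-node cluster and all parallel degrees are assumptions for illustration:

# The generation (vLLM) mesh and the training mesh must not overlap,
# so their world sizes must add up to the total number of GPUs.
n_nodes, n_gpus_per_node = 2, 8   # assumed cluster size
gdp, gpp, gmp = 4, 1, 2           # generation: "vllm.d4p1m2" -> 8 GPUs
dp, pp, mp = 2, 2, 2              # training:   "d2p2m2"      -> 8 GPUs
assert gpp == 1                   # vLLM/SGLang pipeline parallelism is not supported yet
assert gdp * gpp * gmp + dp * pp * mp == n_nodes * n_gpus_per_node  # 8 + 8 == 16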
View File

@@ -3,6 +3,8 @@
# Licensed under the Apache License, Version 2.0 (the "License").
import collections
import dataclasses
import enum
import itertools
import re
from typing import *
@@ -224,58 +226,94 @@ def resolve_rpc_hooks(
logger.info(f"Add offload hook for rpc {rpc.name} for role {rpc.role}")
def extract_symmetric_allocation(allocation_mode: str) -> Dict | None:
for x, y, z in itertools.permutations(["d", "m", "p"]):
pattern = rf"{x}(\d+){y}(\d+){z}(\d+)"
m = re.match(pattern, allocation_mode)
class AllocationType(enum.Enum):
DECOUPLED = 1
GLOBAL_HYBRID = 2
@dataclasses.dataclass
class AllocationMode:
type_: AllocationType
parallel_strat: Dict[str, Dict[str, int]]
def is_decoupled(self):
return self.type_ == AllocationType.DECOUPLED
def is_global_hybrid(self):
return self.type_ == AllocationType.GLOBAL_HYBRID
@classmethod
def from_str(cls, allocation_mode: str):
alloc_3d = AllocationMode.extract_3d_alloc(allocation_mode)
alloc_hybrid = AllocationMode.extract_key_value_alloc(allocation_mode)
alloc_decoupled = AllocationMode.extract_decoupled_alloc(allocation_mode)
if alloc_decoupled:
return cls(AllocationType.DECOUPLED, alloc_decoupled)
if alloc_3d:
return cls(AllocationType.GLOBAL_HYBRID, alloc_3d)
if alloc_hybrid:
return cls(AllocationType.GLOBAL_HYBRID, alloc_hybrid)
raise NotImplementedError(f"Failed to parse allocation: {allocation_mode}")
@staticmethod
def extract_3d_alloc(allocation_mode: str) -> Dict | None:
for x, y, z in itertools.permutations(["d", "m", "p"]):
pattern = rf"{x}(\d+){y}(\d+){z}(\d+)"
m = re.match(pattern, allocation_mode)
if not m:
continue
a, b, c = map(int, m.groups())
# to be consistent with the key-value pattern
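# e.g. a hypothetical "d4m2p1" yields {"*": {"d": 4, "m": 2, "p": 1}}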
return {
"*": {
x: a,
y: b,
z: c,
}
}
@staticmethod
def extract_decoupled_alloc(allocation_mode: str) -> Dict | None:
pattern = re.compile(
r"(?:(?:vllm|sglang)\.(.+?)\+(.+))|(?:(.+?)\+(?:vllm|sglang)\.(.+))"
)
m = pattern.match(allocation_mode)
if not m:
continue
a, b, c = map(int, m.groups())
return {
x: a,
y: b,
z: c,
}
def extract_decoupled_vllm_train_allocation(allocation_mode: str) -> Dict | None:
pattern = re.compile(r"(?:vllm\.(.+?)\+(.+))|(?:(.+?)\+vllm\.(.+))")
m = pattern.match(allocation_mode)
if not m:
return
if m.group(1):
vllm_alloc = m.group(1)
other_alloc = m.group(2)
else:
vllm_alloc = m.group(4)
other_alloc = m.group(3)
vllm_alloc = extract_symmetric_allocation(vllm_alloc)
other_alloc = extract_symmetric_allocation(other_alloc)
if not vllm_alloc:
return
if not other_alloc:
return
other_alloc.update({"vllm." + k: v for k, v in vllm_alloc.items()})
return other_alloc
def parse_key_value_pairs(s: str):
pattern = re.compile(r"([^:,]+):([^:,]+)")
matches = pattern.findall(s)
if not matches:
return None
return {key: value for key, value in matches}
def extract_key_value_allocation(
allocation_mode: str,
) -> Dict[str, Dict[str, int]] | None:
allocs = parse_key_value_pairs(allocation_mode)
if not allocs:
return
for k, v in allocs.items():
v = extract_symmetric_allocation(v)
if not v:
return
allocs[k] = v
return allocs
if m.group(1):
gen_alloc = m.group(1)
other_alloc = m.group(2)
else:
gen_alloc = m.group(4)
other_alloc = m.group(3)
gen_alloc = AllocationMode.extract_3d_alloc(gen_alloc)
if not gen_alloc:
return
other_alloc = AllocationMode.extract_3d_alloc(
other_alloc
) or AllocationMode.extract_key_value_alloc(other_alloc)
if not other_alloc:
return
other_alloc.update({"gen": gen_alloc["*"]})
return other_alloc
@staticmethod
def extract_key_value_alloc(
allocation_mode: str,
) -> Dict[str, Dict[str, int]] | None:
def parse_key_value_pairs(s: str):
pattern = re.compile(r"([^:,]+):([^:,]+)")
matches = pattern.findall(s)
if not matches:
return None
return {key: value for key, value in matches}
allocs = parse_key_value_pairs(allocation_mode)
if not allocs:
return
for k, v in allocs.items():
v = AllocationMode.extract_3d_alloc(v)
if not v:
return
allocs[k] = v["*"]
return allocs
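For reference, a hedged usage sketch of the parser introduced above, run in the context of this module; the allocation strings and the RPC names `actor_train`/`critic_train` are illustrative assumptions:

# Plain 3D parallelization: a single strategy shared by every MFC, stored under "*".
m = AllocationMode.from_str("d4p2m2")
assert m.is_global_hybrid()
assert m.parallel_strat == {"*": {"d": 4, "p": 2, "m": 2}}

# Decoupled generation: the vLLM/SGLang strategy is stored under "gen",
# the training side either under "*" or under each RPC name.
m = AllocationMode.from_str("vllm.d4p1m2+d2p2m2")
assert m.is_decoupled()
assert m.parallel_strat == {"*": {"d": 2, "p": 2, "m": 2}, "gen": {"d": 4, "p": 1, "m": 2}}

# Key-value allocation without a decoupled generation mesh.
m = AllocationMode.from_str("actor_train:d2p2m2,critic_train:d4p1m2")
assert m.is_global_hybrid()
assert m.parallel_strat == {
    "actor_train": {"d": 2, "p": 2, "m": 2},
    "critic_train": {"d": 4, "p": 1, "m": 2},
}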