# AReaL/realhf/experiments/common/utils.py

# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import collections
import dataclasses
import enum
import itertools
import re
from typing import *

import numpy as np
from omegaconf import DictConfig, OmegaConf

from realhf.api.core.config import (
    ModelBackendAbstraction,
    ModelInterfaceType,
    ModelName,
)
from realhf.api.core.dfg import OffloadHook, ParamReallocHook
from realhf.api.quickstart.device_mesh import RPCAllocation
from realhf.api.quickstart.model import (
    ModelTrainEvalConfig,
    ParallelismConfig,
    parallelism_eq,
)
from realhf.base import logging
from realhf.base.topology import PipeModelDataParallelTopology

logger = logging.getLogger("Experiment Common Utils", "benchmark")


def get_topo(
    parallel: ParallelismConfig,
    gradient_checkpointing: bool,
    gradient_accumulation_fusion: bool,
    max_prompt_len: Optional[int] = None,
) -> PipeModelDataParallelTopology:
    return PipeModelDataParallelTopology(
        num_mp=parallel.model_parallel_size,
        num_pp=parallel.pipeline_parallel_size,
        num_dp=parallel.data_parallel_size,
        sequence_parallel=parallel.use_sequence_parallel,
        gradient_checkpointing=gradient_checkpointing,
        max_prompt_len=max_prompt_len,
        gradient_accumulation_fusion=gradient_accumulation_fusion,
    )
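

# Illustrative usage sketch (added comment; not in the original file, and the
# concrete field values are assumptions): building a 3D-parallel topology for
# a model with 2-way tensor, 2-way pipeline, and 4-way data parallelism.
#
#   topo = get_topo(
#       ParallelismConfig(
#           model_parallel_size=2,
#           pipeline_parallel_size=2,
#           data_parallel_size=4,
#           use_sequence_parallel=True,
#       ),
#       gradient_checkpointing=True,
#       gradient_accumulation_fusion=False,
#   )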


def get_world_size(parallel: ParallelismConfig) -> int:
    return (
        parallel.model_parallel_size
        * parallel.pipeline_parallel_size
        * parallel.data_parallel_size
    )
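

# Worked example (added comment; not in the original file): with
# model_parallel_size=2, pipeline_parallel_size=2, and data_parallel_size=4,
# get_world_size returns 2 * 2 * 4 = 16 devices for this model.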


def make_train_backend_config(
    model_cfg: ModelTrainEvalConfig, parallel_cfg: ParallelismConfig
):
    if model_cfg.backend == "deepspeed":
        return ModelBackendAbstraction(
            "deepspeed",
            args=dict(
                optimizer_name="adam",
                optimizer_config=dict(
                    lr=model_cfg.optimizer.lr,
                    weight_decay=model_cfg.optimizer.weight_decay,
                    eps=model_cfg.optimizer.eps,
                    betas=(
                        model_cfg.optimizer.beta1,
                        model_cfg.optimizer.beta2,
                    ),
                ),
                lr_scheduler_type=model_cfg.optimizer.lr_scheduler_type,
                warmup_steps_proportion=model_cfg.optimizer.warmup_steps_proportion,
                min_lr_ratio=model_cfg.optimizer.min_lr_ratio,
                zero_stage=(
                    model_cfg.zero_stage
                    if parallel_cfg.pipeline_parallel_size == 1
                    else min(model_cfg.zero_stage, 1)
                ),
                offload_optimizer_state=model_cfg.optimizer.offload,
                offload_param=model_cfg.offload,
                enable_bf16=model_cfg.enable_bf16,
                enable_fp16=model_cfg.enable_fp16,
            ),
        )
    elif model_cfg.backend == "megatron":
        if model_cfg.optimizer.offload or model_cfg.offload:
            raise ValueError("Offload is not supported in Megatron backend.")
        if model_cfg.zero_stage == 3:
            raise ValueError("Zero stage 3 is not supported in Megatron backend.")
        if model_cfg.zero_stage == 2:
            logger.warning(
                "Megatron does not support ZeRO stage 2. Degenerates to stage 1."
            )
            model_cfg.zero_stage = 1
        megatron_args: Dict[str, Any] = OmegaConf.to_container(model_cfg.megatron)
        return ModelBackendAbstraction(
            "megatron",
            args=dict(
                enable_bf16=model_cfg.enable_bf16,
                enable_fp16=model_cfg.enable_fp16,
                zero_stage=model_cfg.zero_stage,
                optimizer=model_cfg.optimizer,
                **megatron_args,
            ),
        )
    elif model_cfg.backend == "mock_train":
        return ModelBackendAbstraction(
            "mock_train",
            args=dict(
                optimizer_name="adam",
                optimizer_config=dict(
                    lr=model_cfg.optimizer.lr,
                    weight_decay=model_cfg.optimizer.weight_decay,
                    eps=model_cfg.optimizer.eps,
                    betas=(
                        model_cfg.optimizer.beta1,
                        model_cfg.optimizer.beta2,
                    ),
                ),
            ),
        )
    else:
        raise NotImplementedError(f"Backend {model_cfg.backend} is not supported.")
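

# Illustrative call sketch (added comment; not in the original file; the config
# values are hypothetical):
#
#   backend = make_train_backend_config(model_cfg, parallel_cfg)
#
# With backend="deepspeed", zero_stage=3, and pipeline_parallel_size > 1, the
# ZeRO stage is capped to 1 above; with backend="megatron" and zero_stage=2, a
# warning is logged and the stage is downgraded to 1 before the
# ModelBackendAbstraction("megatron", ...) is built.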


def make_inf_backend_config(
    model_cfg: ModelTrainEvalConfig, parallel_cfg: ParallelismConfig
):
    return ModelBackendAbstraction("inference")


def resolve_replica_ids(
    rpc_allocs: List[RPCAllocation], models: Dict[str, ModelTrainEvalConfig]
):
    role_cnt = collections.defaultdict(int)
    first_device_mesh = dict()
    first_parallel = dict()
    first_rpc = dict()
    for alloc in rpc_allocs:
        rpc = alloc.rpc
        if rpc.role not in first_device_mesh:
            first_device_mesh[rpc.role] = alloc.device_mesh
            first_parallel[rpc.role] = alloc.parallel
            first_rpc[rpc.role] = rpc
            continue
        model_cfg = models[rpc.role]
        if (rpc.is_train() and first_rpc[rpc.role].is_generate()) or (
            rpc.is_generate() and first_rpc[rpc.role].is_train()
        ):
            if model_cfg.vllm.hybrid_train:
                role_cnt[rpc.role] += 1
                rpc.model_name = ModelName(rpc.role, role_cnt[rpc.role])
                continue
        if alloc.device_mesh != first_device_mesh[rpc.role] or not parallelism_eq(
            alloc.parallel, first_parallel[rpc.role]
        ):
            role_cnt[rpc.role] += 1
            rpc.model_name = ModelName(rpc.role, role_cnt[rpc.role])
            continue
        assert rpc.model_name.replica_id == 0
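

# Clarifying note (added; not in the original file): within one role, the first
# RPC encountered keeps replica id 0. A later RPC of the same role receives a
# fresh replica id when it is the train/generate counterpart of the first RPC
# under vLLM hybrid_train, or when its device mesh or parallelism differs from
# the first RPC; otherwise it must still carry replica id 0 (the assertion
# above).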


def resolve_rpc_hooks(
    rpc_allocs: List[RPCAllocation], model_configs: Dict[str, ModelTrainEvalConfig]
):
    role_interface_types = collections.defaultdict(set)
    for rpc_alloc in rpc_allocs:
        role_interface_types[rpc_alloc.rpc.role].add(rpc_alloc.rpc.interface_type)

    for i, rpc_alloc in enumerate(rpc_allocs):
        rpc = rpc_alloc.rpc
        parallel = rpc_alloc.parallel
        device_mesh = rpc_alloc.device_mesh
        # Check param realloc hooks for train_step RPCs.
        if rpc.interface_type == ModelInterfaceType.TRAIN_STEP:
            for j, other in enumerate(rpc_allocs):
                if rpc.name == other.rpc.name:
                    continue
                if rpc.role != other.rpc.role:
                    continue
                if (
                    parallelism_eq(parallel, other.parallel)
                    and device_mesh == other.device_mesh
                    and not (
                        model_configs[rpc.role].vllm.hybrid_train
                        and other.rpc.is_generate()
                    )
                ):
                    continue
                self_config = model_configs[rpc.model_name.role]
                other_config = model_configs[other.rpc.model_name.role]
                if (
                    self_config.backend == "deepspeed"
                    or other_config.backend == "deepspeed"
                ):
                    raise ValueError(
                        "Param realloc hooks are not supported in DeepSpeed backend."
                    )
                other.rpc.add_pre_hook(ParamReallocHook(source=rpc.model_name))
                other.rpc.add_post_hook(ParamReallocHook(target=rpc.model_name))
                logger.info(
                    f"Add param sync hooks between "
                    f"{rpc.name} and {other.rpc.name} for role {rpc.role}"
                )
        # Add offload hooks for inference and generate RPCs.
        # Add the offload hook only if the role will not be trained (e.g., a reward
        # model) and its allocation overlaps with at least one other RPC.
        # As a result, a standalone inference/generate RPC will not be offloaded.
        overlapped_with_other = False
        for other in rpc_allocs:
            if rpc.name == other.rpc.name:
                continue
            if np.any(np.logical_and(other.device_mesh.mapping, device_mesh.mapping)):
                overlapped_with_other = True
                break
        if (
            ModelInterfaceType.TRAIN_STEP not in role_interface_types[rpc.role]
            and overlapped_with_other
        ):
            rpc.add_post_hook(OffloadHook())
            logger.info(f"Add offload hook for rpc {rpc.name} for role {rpc.role}")
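

# Clarifying note (added; not in the original file, and the description of the
# hook direction is an interpretation of the code above): when a train_step RPC
# shares a role with another RPC on a different mesh/parallelism (or with a
# hybrid-train vLLM generate RPC), the other RPC gets a pre-hook sourcing
# parameters from the training model and a post-hook targeting it, so the
# non-training RPC runs with the latest trained weights and the parameters are
# reallocated back afterwards.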


class AllocationType(enum.Enum):
    DECOUPLED = 1
    GLOBAL_HYBRID = 2


@dataclasses.dataclass
class AllocationMode:
    type_: AllocationType
    parallel_strat: Dict[str, Dict[str, int]]

    def is_decoupled(self):
        return self.type_ == AllocationType.DECOUPLED

    def is_global_hybrid(self):
        return self.type_ == AllocationType.GLOBAL_HYBRID

    @classmethod
    def from_str(cls, allocation_mode: str):
        alloc_3d = AllocationMode.extract_3d_alloc(allocation_mode)
        alloc_hybrid = AllocationMode.extract_key_value_alloc(allocation_mode)
        alloc_decoupled = AllocationMode.extract_decoupled_alloc(allocation_mode)
        if alloc_decoupled:
            return cls(AllocationType.DECOUPLED, alloc_decoupled)
        if alloc_3d:
            return cls(AllocationType.GLOBAL_HYBRID, alloc_3d)
        if alloc_hybrid:
            return cls(AllocationType.GLOBAL_HYBRID, alloc_hybrid)
        raise NotImplementedError(f"Failed to parse allocation: {allocation_mode}")

    @staticmethod
    def extract_3d_alloc(allocation_mode: str) -> Dict | None:
        for x, y, z in itertools.permutations(["d", "m", "p"]):
            pattern = rf"{x}(\d+){y}(\d+){z}(\d+)"
            m = re.match(pattern, allocation_mode)
            if not m:
                continue
            a, b, c = map(int, m.groups())
            # To be consistent with the key-value pattern.
            return {
                "*": {
                    x: a,
                    y: b,
                    z: c,
                }
            }
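
    # Illustrative parse (added comment; not in the original file):
    #   AllocationMode.extract_3d_alloc("d2m4p1")
    #   # -> {"*": {"d": 2, "m": 4, "p": 1}}
    # i.e., 2-way data, 4-way tensor/model, and 1-way pipeline parallelism
    # shared by all model function calls.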

    @staticmethod
    def extract_decoupled_alloc(allocation_mode: str) -> Dict | None:
        pattern = re.compile(
            r"(?:(?:vllm|sglang)\.(.+?)\+(.+))|(?:(.+?)\+(?:vllm|sglang)\.(.+))"
        )
        m = pattern.match(allocation_mode)
        if not m:
            return
        if m.group(1):
            gen_alloc = m.group(1)
            other_alloc = m.group(2)
        else:
            gen_alloc = m.group(4)
            other_alloc = m.group(3)
        gen_alloc = AllocationMode.extract_3d_alloc(gen_alloc)
        if not gen_alloc:
            return
        other_alloc = AllocationMode.extract_3d_alloc(
            other_alloc
        ) or AllocationMode.extract_key_value_alloc(other_alloc)
        if not other_alloc:
            return
        other_alloc.update({"gen": gen_alloc["*"]})
        return other_alloc
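
    # Illustrative parse (added comment; not in the original file): a decoupled
    # allocation dedicates part of the cluster to a vLLM/SGLang generation
    # backend and the rest to the other model function calls, e.g.
    #   AllocationMode.extract_decoupled_alloc("vllm.d4m1p1+d2m2p2")
    #   # -> {"*": {"d": 2, "m": 2, "p": 2}, "gen": {"d": 4, "m": 1, "p": 1}}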

    @staticmethod
    def extract_key_value_alloc(
        allocation_mode: str,
    ) -> Dict[str, Dict[str, int]] | None:
        def parse_key_value_pairs(s: str):
            pattern = re.compile(r"([^:,]+):([^:,]+)")
            matches = pattern.findall(s)
            if not matches:
                return None
            return {key: value for key, value in matches}

        allocs = parse_key_value_pairs(allocation_mode)
        if not allocs:
            return
        for k, v in allocs.items():
            v = AllocationMode.extract_3d_alloc(v)
            if not v:
                return
            allocs[k] = v["*"]
        return allocs
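

# Illustrative parse (added comment; not in the original file; the role names
# are examples): a key-value allocation assigns a separate 3D layout per role,
# e.g.
#   AllocationMode.extract_key_value_alloc("actor:d4m2p1,critic:d2m2p2")
#   # -> {"actor": {"d": 4, "m": 2, "p": 1}, "critic": {"d": 2, "m": 2, "p": 2}}
# AllocationMode.from_str dispatches among the three formats, preferring the
# decoupled form, then the single 3D form, then this key-value form.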


def asdict(cfg):
    if isinstance(cfg, (OmegaConf, DictConfig)):
        return OmegaConf.to_container(cfg, resolve=True)
    return dataclasses.asdict(cfg)
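

# Usage sketch (added comment; not in the original file; the field values are
# hypothetical): asdict accepts either an OmegaConf DictConfig or a plain
# dataclass instance and returns a plain dict, e.g.
#   asdict(OmegaConf.create({"lr": 1e-5}))  # -> {"lr": 1e-5}
#   asdict(some_dataclass_instance)         # -> nested dict of its fields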