# AReaL/realhf/experiments/common/check.py
#
# 82 lines
# 2.8 KiB
# Python

# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import os
from typing import List
import realhf.api.core.model_api as model_api
from realhf.api.quickstart.device_mesh import RPCAllocation
from realhf.api.quickstart.model import ModelTrainEvalConfig, vLLMConfig
from realhf.base import logging
# Module-level logger; used by the validation checks in this file to emit
# non-fatal configuration warnings.
logger = logging.getLogger(__name__)
def check_is_realhf_native_impl(_cls):
    """Return whether ``_cls`` comes from a module whose name starts with ``realhf``.

    Args:
        _cls: Any object exposing ``__module__`` (typically a class).

    Returns:
        True if ``_cls.__module__`` begins with the string ``"realhf"``.
    """
    # NOTE(review): this is a plain string-prefix test, so a module such as
    # ``realhf_ext`` would also match — confirm that is intended.
    module_name = _cls.__module__
    return module_name.startswith("realhf")
def check_is_realhf_native_model_interface(name):
    """Return whether ``name`` is one of the hard-coded realhf-native interface names.

    Args:
        name: Interface name to look up.

    Returns:
        True if ``name`` equals "ppo_actor", "ppo_critic", "sft" or "ref_rw".
    """
    # NOTE: the names are listed literally instead of importing the interface
    # modules, so that calling this check does not trigger CUDA initialization.
    native_interfaces = ("ppo_actor", "ppo_critic", "sft", "ref_rw")
    return name in native_interfaces
def check_valid_vllm(role: str, vllm: vLLMConfig, rpc_allocs: List[RPCAllocation]):
    """Validate the vLLM settings used by the model role ``role``.

    Only the RPCs whose ``rpc.role`` equals ``role`` are considered.

    Args:
        role: Model role whose RPCs are inspected.
        vllm: vLLM configuration to validate.
        rpc_allocs: All RPC allocations of the experiment.

    Raises:
        ValueError: If ``vllm.hybrid_train`` is set while
            ``vllm.enforce_eager`` is not.

    A warning (not an error) is logged when ``vllm.hybrid_train`` is set but
    none of the role's RPCs is a training RPC.
    """
    role_rpcs = []
    for alloc in rpc_allocs:
        if alloc.rpc.role == role:
            role_rpcs.append(alloc.rpc)

    # Every remaining check concerns hybrid training only.
    if not vllm.hybrid_train:
        return

    has_train_rpc = any(r.is_train() for r in role_rpcs)
    if not has_train_rpc:
        logger.warning("vLLM hybrid_train is enabled, but no training RPCs are found.")

    if not vllm.enforce_eager:
        raise ValueError("vLLM hybrid_train requires eager mode to be enabled.")
def check_valid_optimizer(model: ModelTrainEvalConfig):
    """Validate the ratio-valued fields of ``model.optimizer``.

    Both ``min_lr_ratio`` and ``warmup_steps_proportion`` must lie in the
    closed interval [0.0, 1.0].

    Args:
        model: Model configuration whose ``optimizer`` is checked.

    Raises:
        ValueError: If either field is outside [0.0, 1.0] or is NaN. Fields
            are checked in the order listed above, so the first invalid one
            is reported.
    """
    optimizer = model.optimizer
    for field_name in ("min_lr_ratio", "warmup_steps_proportion"):
        value = getattr(optimizer, field_name)
        # Chained comparison is False for NaN, so NaN is rejected as well;
        # the previous `x < 0.0 or x > 1.0` form silently accepted NaN.
        if not 0.0 <= value <= 1.0:
            raise ValueError(f"Invalid {field_name}: {value}")
def check_valid_model_and_path(role: str, model: ModelTrainEvalConfig):
    """Ensure the checkpoint path configured for ``role`` exists locally.

    Args:
        role: Model role name, used only in the error message.
        model: Model configuration whose ``path`` is checked.

    Raises:
        FileNotFoundError: If ``os.path.exists(model.path)`` is False.
    """
    model_path = model.path
    if os.path.exists(model_path):
        return
    raise FileNotFoundError(
        f"The model path `{model_path}` for `{role}` does not exist locally. "
        "You must download the HuggingFace checkpoint before loading it."
    )
def check_valid_parallel_batch_size(rpc_alloc: RPCAllocation):
    """Validate that an RPC's batch size is large enough for its parallel layout.

    ``rpc.n_seqs`` must be at least
    ``factor * dp_size * pp_size * rpc.min_n_seqs_per_pass * mb_spec.n_mbs``,
    where ``factor`` is 2 for training RPCs with ``pipeline_parallel_size > 1``
    and 1 otherwise.

    Args:
        rpc_alloc: RPC allocation whose ``rpc`` and ``parallel`` settings are
            checked.

    Raises:
        ValueError: If ``rpc.n_seqs`` is below the required minimum. The
            message reports every term of the computation.
    """
    rpc = rpc_alloc.rpc
    mb_spec = rpc.mb_spec
    dp_size = rpc_alloc.parallel.data_parallel_size
    pp_size = rpc_alloc.parallel.pipeline_parallel_size

    # NOTE(review): training with pipeline parallelism doubles the minimum;
    # presumably the pipelined train schedule needs extra micro-batches per
    # stage — confirm against the pipeline engine.
    factor = 2 if rpc.is_train() and pp_size > 1 else 1
    required = factor * dp_size * pp_size * rpc.min_n_seqs_per_pass * mb_spec.n_mbs

    # Explicit check instead of `assert`: assertions are stripped when Python
    # runs with -O, which would silently disable this validation.
    if rpc.n_seqs < required:
        raise ValueError(
            "Invalid parallel batch size and batch size configuration: "
            f"RPC `{rpc.name}` has n_seqs={rpc.n_seqs}, but at least {required} "
            f"are required (factor={factor} x dp_size={dp_size} x "
            f"pp_size={pp_size} x min_n_seqs_per_pass={rpc.min_n_seqs_per_pass} "
            f"x n_mbs={mb_spec.n_mbs}; mb_spec={mb_spec})."
        )