PullRequest: 11 Support automatically launching evaluation jobs during training

Merge branch mzy/auto-eval of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/11

Signed-off-by: 博惟 <bowei.fw@antgroup.com>


* test
* move evaluator to main process
* .
* clear codes
* add docstring
* .
* separate wandb groups
* .
晓雷 2025-03-05 18:06:40 +08:00
parent deccd47a22
commit e9bf229581
14 changed files with 347 additions and 14 deletions

View File

@ -5,6 +5,7 @@ import subprocess
from glob import glob
import numpy as np
import wandb
from rm_maj_eval import group_pred
from tqdm import tqdm
from transformers import AutoTokenizer
@ -29,6 +30,7 @@ def parse_args():
parser.add_argument("--overwrite", action="store_true")
parser.add_argument("--evaluate_train", action="store_true")
parser.add_argument("--max_gen_tokens", default=32768, type=int)
args = parser.parse_args()
if args.output_path is None:
args.output_path = args.model_path
@ -145,7 +147,7 @@ def process_single_data_name(args, data_name, base_dir, tokenizer):
if __name__ == "__main__":
args = parse_args()
print(f"Evaluation output to {args.output_path}")
assert args.num_sample_nodes * args.samples_per_node >= args.n_sampling
eval_dir = (
@ -155,6 +157,7 @@ if __name__ == "__main__":
)
base_dir = os.path.join(args.output_path, eval_dir)
os.makedirs(base_dir, exist_ok=True)
tokenizer = AutoTokenizer.from_pretrained(args.model_path)
result_path = os.path.join(base_dir, f"aggregate_parallel_{args.prompt_type}.json")
@ -228,10 +231,10 @@ if __name__ == "__main__":
from prettytable import PrettyTable
table = PrettyTable()
filed_names = ["dataset"] + list(all_results[args.data_names[0]].keys())
table.field_names = filed_names
field_names = ["dataset"] + list(all_results[args.data_names[0]].keys())
table.field_names = field_names
for k, v in all_results.items():
table.add_row([k, *[round(v[x], 1) for x in filed_names[1:]]])
table.add_row([k, *[round(v[x], 1) for x in field_names[1:]]])
print(table)
except:

View File

@ -137,7 +137,7 @@ def generate_in_parallel(requests, model_args, sampling_params, data_parallel_si
def run_inference_one_model(
model_args: dict, sampling_params, requests, cuda_visisble_devices
):
os.environ["VLLM_LOGGING_LEVEL"] = "DEBUG"
os.environ["VLLM_LOGGING_LEVEL"] = "INFO"
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
[str(x) for x in cuda_visisble_devices]
)

View File

@ -3,7 +3,7 @@ vllm
tqdm
datasets
torch
transformers
transformers==4.47.0
python_dateutil
flash_attn
@ -12,4 +12,6 @@ sympy==1.12
antlr4-python3-runtime==4.11.1 # ! The version needs to be compatible with sympy.
word2number
Pebble
timeout-decorator
prettytable
timeout-decorator
wandb

View File

@ -0,0 +1,7 @@
#!/bin/bash
# Users should run this script under the AReaL directory
/usr/bin/python3 -m pip install -e evaluation/latex2sympy/
/usr/bin/python3 -m pip install -r evaluation/requirements.txt
cd evaluation && /usr/bin/python3 eval_and_aggregate.py --model_path $1 --output_path $2 --data_names $3 --max_gen_tokens $4 --prompt_type $5

View File

@ -234,6 +234,36 @@ class ExperimentScheduling:
controller_image: str = _LLM_CPU_IMAGE
@dataclasses.dataclass
class AutomaticEvaluator:
"""Configuration for automatic evaluation.
:param data_names: Datasets used for evaluation, separated by commas. Currently supported datasets are stored under ./evaluation/data
and include "aime24", "amc23" and "math_500". For example, to evaluate on both "aime24" and "amc23",
set this field to "aime24,amc23".
:type data_names: str
:param max_gen_tokens: Maximum number of tokens to generate during evaluation.
:type max_gen_tokens: int
:param max_concurrent_jobs: Maximum number of concurrent evaluation jobs. If the number of running jobs already equals
`max_concurrent_jobs` when a new checkpoint is saved, its evaluation job waits until earlier jobs complete.
:type max_concurrent_jobs: int
:param eval_job_image: Container image used to launch evaluation jobs. If set to None, evaluation jobs use the
GPU image used for training.
:type eval_job_image: Optional[str]
:param initial_checkpoint_path: Path of an initial checkpoint to evaluate. If specified, this checkpoint is evaluated
and its results are stored as global_step = 0.
:type initial_checkpoint_path: Optional[str]
:param prompt_type: Prompt format used in evaluation.
:type prompt_type: str
"""
data_names: str = "aime24"
max_gen_tokens: int = 32768
max_concurrent_jobs: int = 3
eval_job_image: Optional[str] = None
initial_checkpoint_path: Optional[str] = None
prompt_type: str = "deepscaler"
@dataclasses.dataclass
class WandBConfig:
mode: str = "disabled"
@ -256,6 +286,11 @@ class ExperimentConfig:
model_worker: List[ModelWorker] = dataclasses.field(default_factory=list)
# master_worker will be set automatically
master_worker: Optional[List[MasterWorker]] = None
# automatic evaluation
auto_eval: bool = False
evaluator: AutomaticEvaluator = dataclasses.field(
default_factory=AutomaticEvaluator
)
def __post_init__(self):
self.master_worker = [

View File

@ -18,6 +18,7 @@ import realhf.base.recover as recover
import realhf.scheduler.client as sched_client
import realhf.system as system
from realhf.scheduler.client import JobException, JobState
from realhf.scheduler.evaluator import AutomaticEvaluator
logger = logging.getLogger("main", "system")
@ -85,10 +86,17 @@ def main_start(args, recover_count: int = 0):
# Run initial_setup to go through all sanity checks.
try:
exp_cfg = experiment.initial_setup()
assert isinstance(exp_cfg, config_package.ExperimentConfig)
exp_cfg.lazy_init()
except Exception as e:
raise RuntimeError("Experiment initial setup failed.") from e
evaluator = (
AutomaticEvaluator(exp_cfg.evaluator, exp_cfg.wandb)
if exp_cfg.auto_eval
else None
)
if args.mode == "local":
assert (
args.recover_mode == "disabled"
@ -167,6 +175,7 @@ def main_start(args, recover_count: int = 0):
expr_name=expr_name,
trial_name=trial_name,
schedule_strategy=args.schedule_strategy,
evaluator=evaluator,
)
setup = experiment.scheduling_setup()

View File

@ -105,7 +105,6 @@ BASE_ENVIRONS = {
# "TORCH_SHOW_CPP_STACKTRACES": "1",
# "RAY_DEDUP_LOGS": "0", # disable ray log deduplication
"CUDA_DEVICE_MAX_CONNECTIONS": "1",
"PYTHONUSERBASE": "/nonsense", # a random PYTHONUSERBASE to avoid local user site-packages interference
"OMP_NUM_THREADS": str(min(os.cpu_count(), 32)),
# torch.distributed.all_reduce does not free the input tensor until
# the synchronization point. This causes the memory usage to grow

View File

@ -29,6 +29,7 @@ from realhf.api.core.config import (
from realhf.api.core.dfg import MFCDef, ModelInterfaceType, build_graph
from realhf.api.core.model_api import HF_MODEL_FAMILY_REGISTRY
from realhf.api.core.system_api import (
AutomaticEvaluator,
Experiment,
ExperimentConfig,
ExperimentSaveEvalControl,
@ -183,6 +184,12 @@ class CommonExperimentConfig(Experiment):
torch.cuda.empty_cache() before each RPC in model worker
If enabled, there will be a ~0.1s overhead per RPC.
:type torch_cache_mysophobia: bool
:param auto_eval: Whether to run automatic evaluation during training. When enabled, an evaluation
job is submitted whenever a checkpoint is saved, and the results are logged to disk and
to wandb if wandb is active.
:type auto_eval: bool
:param auto_eval_config: Configuration for automatic evaluation.
:type auto_eval_config: AutomaticEvaluator
:param cpus_per_master_worker: The number of CPUs for each master worker.
:param mem_per_master_worker: The size of memory for each master worker, measured in MB.
:param cpus_per_model_worker: The number of CPUs for each model worker.
@ -214,6 +221,12 @@ class CommonExperimentConfig(Experiment):
default_factory=ExperimentSaveEvalControl
)
torch_cache_mysophobia: bool = True
# Options for automatic evaluation
auto_eval: bool = False
auto_eval_config: AutomaticEvaluator = dataclasses.field(
default_factory=AutomaticEvaluator
)
# Options for worker resources
cpus_per_master_worker: int = 4
mem_per_master_worker: int = 20000
cpus_per_model_worker: int = 4
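Downstream, enabling the feature only requires flipping the two new fields on a CommonExperimentConfig subclass. A minimal sketch, assuming an already-constructed experiment config object named exp (its class and construction are omitted, so the snippet is illustrative rather than runnable end-to-end):

from realhf.api.core.system_api import AutomaticEvaluator

# exp is any CommonExperimentConfig subclass instance, e.g. the PPO math
# experiment shown later in this diff; how it is built is elided here.
exp.auto_eval = True
exp.auto_eval_config = AutomaticEvaluator(
    data_names="aime24",
    max_concurrent_jobs=3,
)

For the PPO math experiment below, auto_eval_config.initial_checkpoint_path is additionally overwritten with the actor path, so the starting checkpoint itself is evaluated and logged as global_step 0.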
@ -716,6 +729,8 @@ class CommonExperimentConfig(Experiment):
wandb=self.wandb,
model_rpcs=[rpc_alloc.rpc for rpc_alloc in rpc_allocs],
model_worker=model_worker,
auto_eval=self.auto_eval,
evaluator=self.auto_eval_config,
)
def __check_legal_allocation_options(self):

View File

@ -544,12 +544,15 @@ class PPOMATHConfig(CommonExperimentConfig):
######### The main difference from normal PPO #########
model_worker = self._get_model_worker_configs(rpc_allocs)
self.auto_eval_config.initial_checkpoint_path = self.actor.path
return ExperimentConfig(
exp_ctrl=self.exp_ctrl,
wandb=self.wandb,
model_rpcs=[rpc_alloc.rpc for rpc_alloc in rpc_allocs],
model_worker=model_worker,
auto_eval=self.auto_eval,
evaluator=self.auto_eval_config,
)

View File

@ -162,7 +162,8 @@ def make(mode, expr_name, trial_name, **kwargs) -> SchedulerClient:
from realhf.scheduler.slurm.client import SlurmSchedulerClient
schedule_strategy = kwargs.get("schedule_strategy", "empty_first")
return SlurmSchedulerClient(expr_name, trial_name, schedule_strategy)
evaluator = kwargs.get("evaluator", None)
return SlurmSchedulerClient(expr_name, trial_name, schedule_strategy, evaluator)
elif mode == "local":
from realhf.scheduler.local.client import LocalSchedulerClient

View File

@ -0,0 +1,246 @@
import json
import os
import re
import subprocess
import time
from typing import Dict
import wandb
import realhf.api.core.system_api as config_pkg
from realhf.base import cluster, constants, logging
logger = logging.getLogger("AutomaticEvaluator", "colored")
class AutomaticEvaluator:
def __init__(
self,
config: config_pkg.AutomaticEvaluator,
wandb_config: config_pkg.WandBConfig,
):
self.__running_processes: Dict[int, subprocess.Popen] = {}
self.__start_time = {}
self.__done_steps = []
self.__wandb_log_steps = []
self.__pending_ckpts = {}
self.__image = config.eval_job_image or cluster.spec.gpu_image
self.__data_names = config.data_names
self.__max_gen_tokens = config.max_gen_tokens
self.__max_concurrent_jobs = config.max_concurrent_jobs
self.__prompt_type = config.prompt_type
self.__wandb_config = wandb_config
self.__config = config
self.__wandb_inited = False
# Check evaluated checkpoints by logs
former_output_dir = os.path.join(
constants.LOG_ROOT,
constants.experiment_name(),
constants.trial_name(),
"eval_output",
)
if os.path.exists(former_output_dir):
for log_dir in os.listdir(former_output_dir):
match = re.match(r"globalstep(\d+)", log_dir)
if not match:
continue
global_step = int(match.groups()[0])
self.__done_steps.append(global_step)
logger.info(
f"Initializing AutomaticEvaluator: \n"
f"eval_job_image: {config.eval_job_image}\n"
f"data_names: {config.data_names}\n"
f"max_gen_tokens: {config.max_gen_tokens}\n"
f"max_concurrent_jobs: {config.max_concurrent_jobs}\n"
f"Existing eval outputs for global steps: {self.__done_steps}"
)
if self.__config.initial_checkpoint_path and 0 not in self.__done_steps:
self.__pending_ckpts[0] = self.__config.initial_checkpoint_path
if not cluster.spec.cluster_type == "slurm":
raise NotImplementedError(
"Currently only support automatic evaluation for slurm"
)
def __lazy_wandb_init(self):
# Initializing wandb for evaluator
wandb.login()
wandb.init(
mode=self.__wandb_config.mode,
entity=self.__wandb_config.entity,
project=self.__wandb_config.project or constants.experiment_name(),
name=self.__wandb_config.name or f"{constants.trial_name()}_eval",
job_type=self.__wandb_config.job_type,
group=self.__wandb_config.group
or f"{constants.experiment_name()}_{constants.trial_name()}",
notes=self.__wandb_config.notes,
tags=self.__wandb_config.tags,
config=self.__wandb_config.config,
dir=os.path.join(
constants.LOG_ROOT, constants.experiment_name(), constants.trial_name()
),
force=True,
id=f"{constants.experiment_name()}_{constants.trial_name()}_eval",
resume="allow",
settings=wandb.Settings(start_method="fork"),
)
def __check_new_ckpts(self):
save_path = os.path.join(
constants.MODEL_SAVE_ROOT,
constants.experiment_name(),
constants.trial_name(),
"actor",
)
if not os.path.exists(save_path):
return
for ckpt_dir in os.listdir(save_path):
match = re.match(r"epoch(\d+)epochstep(\d+)globalstep(\d+)", ckpt_dir)
if not match:
continue
_, _, global_step = map(int, match.groups())
if not global_step in (
list(self.__running_processes.keys())
+ list(self.__pending_ckpts.keys())
+ self.__done_steps
):
abs_ckpt_dir = os.path.join(save_path, ckpt_dir)
logger.info(
f"Found new checkpoint (globalstep{global_step}) at {abs_ckpt_dir}"
)
self.__pending_ckpts[global_step] = os.path.join(
save_path, abs_ckpt_dir
)
def __check_and_maybe_submit_jobs(self):
for global_step, process in self.__running_processes.items():
result = process.poll()
if not result is None:
self.__done_steps.append(global_step)
start_time = self.__start_time[global_step]
logger.info(
f"Evaluation of checkpoint (globalstep{global_step}) is done, returncode={process.returncode}, "
f"time passed {time.perf_counter() - start_time:.3f} s."
)
for done in self.__done_steps:
if done in self.__running_processes:
self.__running_processes.pop(done)
self.__start_time.pop(done)
submitted = []
# Jobs should be submitted in the order of their global steps
ordered_steps = sorted(self.__pending_ckpts.keys())
for global_step in ordered_steps:
ckpt_path = self.__pending_ckpts[global_step]
if len(self.__running_processes) >= self.__max_concurrent_jobs:
return
self.__submit_one(global_step, ckpt_path)
submitted.append(global_step)
for global_step in submitted:
self.__pending_ckpts.pop(global_step)
def __submit_one(self, global_step, ckpt_path):
output_path = self.eval_output_path(global_step)
os.makedirs(output_path, exist_ok=True)
log_file = open(os.path.join(output_path, "output.log"), "w")
if cluster.spec.cluster_type == "slurm":
cmd = self.slurm_eval_cmd(global_step, ckpt_path)
else:
raise NotImplementedError(
"AutomaticEvaluator does only support slurm cluster."
)
logger.info(
f"Submitting evaluation job of checkpoint at {ckpt_path} (globalstep{global_step}), "
f"command: {cmd}"
)
self.__running_processes[global_step] = subprocess.Popen(
cmd,
stdout=log_file,
stderr=log_file,
shell=True,
)
self.__start_time[global_step] = time.perf_counter()
def __maybe_parse_and_log_to_wandb(self):
# Note that after recovery, all previously finished steps will be
# logged again in case some data points are missing. If a data
# point has already been logged, wandb will only raise a warning.
to_log = list(
filter(
lambda x: x not in self.__wandb_log_steps,
(
self.__done_steps
+ list(self.__running_processes.keys())
+ list(self.__pending_ckpts.keys())
),
)
)
while to_log:
# wandb should always log the smallest global step whose eval job
# has been submitted but not yet logged. Until this step is logged,
# later steps have to wait.
global_step = min(to_log)
result_path = os.path.join(
self.eval_output_path(global_step),
f"math_eval_{self.__max_gen_tokens}",
f"aggregate_parallel_{self.__prompt_type}.json",
)
if not os.path.exists(result_path):
break
if not self.__wandb_inited:
self.__lazy_wandb_init()
self.__wandb_inited = True
try:
with open(result_path, "r") as fp:
data = json.load(fp)
except json.JSONDecodeError:
logger.warning(f"JSON decoding for eval result in {result_path} failed")
continue
except FileNotFoundError:
logger.warning(
f"{result_path} not found, but the eval job is done. "
"Maybe the eval job abnormally exited and did not output the result."
)
continue
wandb_data = {}
for data_name, d in data.items():
for k, v in d.items():
wandb_data[f"{data_name}_{k}"] = v
wandb.log(wandb_data, step=global_step)
logger.info(f"Logging eval result {wandb_data} to step {global_step}")
self.__wandb_log_steps.append(global_step)
to_log.remove(global_step)
def step(self):
self.__check_new_ckpts()
self.__check_and_maybe_submit_jobs()
self.__maybe_parse_and_log_to_wandb()
def slurm_eval_cmd(self, global_step, ckpt_path):
slurm_job_name = f"{constants.experiment_name()}_{constants.trial_name()}:eval_globalstep{global_step}"
cmd = (
f"srun --mpi=pmi2 -J {slurm_job_name} --ntasks=1 --cpus-per-task=128 --gres=gpu:8 --mem-per-cpu=12G "
f"singularity exec --nv --pid --writable-tmpfs --bind /storage:/storage "
f"{self.__image} "
f"bash ./evaluation/sh/install_deps_and_eval.sh {ckpt_path} {self.eval_output_path(global_step)} "
f"{self.__data_names} {self.__max_gen_tokens} {self.__prompt_type}"
)
return cmd
def eval_output_path(self, global_step):
return os.path.join(
constants.LOG_ROOT,
constants.experiment_name(),
constants.trial_name(),
"eval_output",
f"globalstep{global_step}",
)
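To make the wandb logging step above concrete, here is a hedged sketch of the flattening done in __maybe_parse_and_log_to_wandb. The metric names and values in the example are purely illustrative; only the data_name -> {metric: value} nesting and the "{data_name}_{key}" flattening follow the code above.

# Hypothetical content of aggregate_parallel_<prompt_type>.json; the real metric
# names come from eval_and_aggregate.py and are not shown in this diff.
data = {
    "aime24": {"accuracy": 40.0, "length": 15000.0},
    "amc23": {"accuracy": 82.5, "length": 9000.0},
}

# Flattening applied before wandb.log(wandb_data, step=global_step):
wandb_data = {f"{name}_{k}": v for name, d in data.items() for k, v in d.items()}
# -> {"aime24_accuracy": 40.0, "aime24_length": 15000.0,
#     "amc23_accuracy": 82.5, "amc23_length": 9000.0}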

View File

@ -14,6 +14,7 @@ import realhf.base.logging as logging
from realhf.base.cluster import spec as cluster_spec
from realhf.base.constants import SLURM_LOCK_FILE_NAME as LOCK_FILE_NAME
from realhf.scheduler.client import JobException, JobInfo, JobState, SchedulerClient
from realhf.scheduler.evaluator import AutomaticEvaluator
from realhf.scheduler.slurm.utils import (
SlurmLaunchInfo,
SlurmResource,
@ -25,12 +26,19 @@ logger = logging.getLogger("Slurm-scheduler")
SCHEDULING_RETRY_INTERVAL_SECONDS = 30
SCHEDULING_TIMEOUT_MAX_SECONDS = 3600 * 24
SCHEDULER_WAIT_CHECK_TIME_INTERVAL = 5
class SlurmSchedulerClient(SchedulerClient):
"""Uses Slurm (https://slurm.schedmd.com/overview.html)."""
def __init__(self, expr_name, trial_name, schedule_strategy):
def __init__(
self,
expr_name: str,
trial_name: str,
schedule_strategy: str,
evaluator: Optional[AutomaticEvaluator],
):
super().__init__(expr_name, trial_name)
self.__schedule_strategy = schedule_strategy
@ -40,6 +48,7 @@ class SlurmSchedulerClient(SchedulerClient):
self.__submission_counter = defaultdict(int)
self.__wprocs_counter = defaultdict(int)
self.__evaluator = evaluator
def submit(self, worker_type, cmd, **kwargs):
self.submit_array(worker_type, cmd, count=1, **kwargs)
@ -245,6 +254,8 @@ class SlurmSchedulerClient(SchedulerClient):
if len(left) < num_jobs_left:
num_jobs_left = len(left)
logger.info(f"Waiting for {num_jobs_left} jobs.")
if self.__evaluator is not None:
self.__evaluator.step()
if deadline is not None and time.time() > deadline:
raise TimeoutError(
f"Timeout waiting for {self.run_name}: {', '.join(sorted(left))}"
@ -276,7 +287,7 @@ class SlurmSchedulerClient(SchedulerClient):
left.remove(job_slurm_name)
if update:
self.__committed_jobs.pop(job_slurm_name)
time.sleep(2)
time.sleep(SCHEDULER_WAIT_CHECK_TIME_INTERVAL)
def __update_all(self):
states = []

View File

@ -450,7 +450,7 @@ class RayController:
It uses the basic Controller to configure workers. Besides, it
launchs all remote workers using Ray, instead of submitting them to
the scheduelr.
the scheduler.
"""
def __init__(self, experiment_name, trial_name):

View File

@ -963,9 +963,10 @@ class MasterWorker(worker_base.Worker):
mode=self.wandb_config.mode,
entity=self.wandb_config.entity,
project=self.wandb_config.project or constants.experiment_name(),
name=self.wandb_config.name or constants.trial_name(),
name=self.wandb_config.name or f"{constants.trial_name()}_train",
job_type=self.wandb_config.job_type,
group=self.wandb_config.group,
group=self.wandb_config.group
or f"{constants.experiment_name()}_{constants.trial_name()}",
notes=self.wandb_config.notes,
tags=self.wandb_config.tags,
config=self.wandb_config.config,
@ -973,6 +974,7 @@ class MasterWorker(worker_base.Worker):
constants.LOG_ROOT, constants.experiment_name(), constants.trial_name()
),
force=True,
id=f"{constants.experiment_name()}_{constants.trial_name()}_train",
resume="allow",
settings=wandb.Settings(start_method="fork"),
)