PullRequest: 20 Fix bugs in auto evaluation

Merge branch mzy/refactor-eval of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/20?tab=diff

Signed-off-by: 博惟 <bowei.fw@antgroup.com>


* test
* move evaluator to main process
* .
* clear codes
* add docstring
* .
* separate wandb groups
* .
* handle eval error
* add check for failed eval
* refactor evaluator
* refactor evaluator, fix master worker wandb login
* .
Commit 3d8be914af by 晓雷, 2025-03-12 16:10:12 +08:00 (parent 26a48be73e)
4 changed files with 226 additions and 166 deletions

View File

@@ -5,7 +5,6 @@ import subprocess
 from glob import glob
 import numpy as np
-import wandb
 from rm_maj_eval import group_pred
 from tqdm import tqdm
 from transformers import AutoTokenizer
@@ -110,7 +109,6 @@ def get_metrics(fname_pattern, tokenizer, is_greedy):
             "num_questions": len(lengths),
         }
     else:
         return {
             "sample_length": np.mean(lengths),
             "sample_pass@1": pass_at_k(results.values(), 1),

View File

@@ -1,9 +1,12 @@
+import dataclasses
+import enum
 import json
 import os
+import pathlib
 import re
 import subprocess
 import time
-from typing import Dict
+from typing import Dict, Optional
 import wandb
@@ -13,6 +16,138 @@ from realhf.base import cluster, constants, logging
 logger = logging.getLogger("AutomaticEvaluator", "colored")
+
+
+class EvaluationStepStatus(enum.Enum):
+    PENDING = 0
+    RUNNING = 1
+    FAILED = 2
+    DONE = 3
+    LOGGED = 4
+
+
+@dataclasses.dataclass
+class EvaluationStep:
+    global_step: int
+    status: EvaluationStepStatus
+    start_time: Optional[float] = None
+    ckpt_dir: Optional[str] = None
+    process: Optional[subprocess.Popen] = None
+
+    @staticmethod
+    def from_ckpt_dir(ckpt_dir):
+        # NOTE: ckpt_dir should be absolute path
+        if pathlib.Path(ckpt_dir).is_symlink():
+            return None
+        _dir = os.path.basename(ckpt_dir)
+        match = re.match(r"epoch(\d+)epochstep(\d+)globalstep(\d+)", _dir)
+        if not match:
+            return None
+        _, _, global_step = map(int, match.groups())
+        return EvaluationStep(
+            global_step=global_step,
+            status=EvaluationStepStatus.PENDING,
+            ckpt_dir=ckpt_dir,
+        )
+
+    @staticmethod
+    def from_output_dir(output_dir):
+        # NOTE: output_dir should be absolute path
+        # Should only be called in recover.
+        _dir = os.path.basename(output_dir)
+        match = re.match(r"globalstep(\d+)", _dir)
+        if not match:
+            return None
+        global_step = int(match.groups()[0])
+        return EvaluationStep(
+            global_step=global_step, status=EvaluationStepStatus.LOGGED
+        )
+
+    @property
+    def output_dir(self):
+        return os.path.join(
+            constants.LOG_ROOT,
+            constants.experiment_name(),
+            constants.trial_name(),
+            "eval_output",
+            f"globalstep{self.global_step}",
+        )
+
+    def slurm_eval_cmd(self, config: config_pkg.AutomaticEvaluator):
+        slurm_job_name = f"{constants.experiment_name()}_{constants.trial_name()}:eval_globalstep{self.global_step}"
+        cmd = (
+            f"srun --mpi=pmi2 -J {slurm_job_name} --ntasks=1 --cpus-per-task=128 --gres=gpu:8 --mem-per-cpu=12G "
+            f"singularity exec --no-home --nv --pid --writable-tmpfs --bind /storage:/storage "
+            f"{config.eval_job_image or cluster.spec.gpu_image} "
+            f"bash ./evaluation/sh/install_deps_and_eval.sh {self.ckpt_dir} {self.output_dir} "
+            f"{config.data_names} {config.max_gen_tokens} {config.prompt_type}"
+        )
+        return cmd
+
+    def submit(self, config: config_pkg.AutomaticEvaluator):
+        os.makedirs(self.output_dir, exist_ok=True)
+        log_file = open(os.path.join(self.output_dir, "output.log"), "w")
+        if cluster.spec.cluster_type == "slurm":
+            cmd = self.slurm_eval_cmd(config)
+        else:
+            raise NotImplementedError(
+                "AutomaticEvaluator does only support slurm cluster."
+            )
+        logger.info(
+            f"Submitting evaluation job of checkpoint at {self.ckpt_dir} (globalstep{self.global_step}), "
+            f"command: {cmd}"
+        )
+        self.process = subprocess.Popen(
+            cmd,
+            stdout=log_file,
+            stderr=log_file,
+            shell=True,
+        )
+        self.start_time = time.perf_counter()
+        self.status = EvaluationStepStatus.RUNNING
+
+    def log(self, config: config_pkg.AutomaticEvaluator) -> bool:
+        result_path = os.path.join(
+            self.output_dir,
+            f"math_eval_{config.max_gen_tokens}",
+            f"aggregate_parallel_{config.prompt_type}.json",
+        )
+        # NOTE: If decoding json failed or not found,
+        # evaluation step will be marked as failed.
+        try:
+            with open(result_path, "r") as fp:
+                data = json.load(fp)
+        except json.JSONDecodeError:
+            logger.warning(f"JSON file {result_path} decoding failed.")
+            self.status = EvaluationStepStatus.FAILED
+            return False
+        except FileNotFoundError:
+            logger.warning(f"JSON file {result_path} does not exist.")
+            self.status = EvaluationStepStatus.FAILED
+            return False
+        wandb_data = {}
+        for data_name, d in data.items():
+            for k, v in d.items():
+                wandb_data[f"{data_name}_{k}"] = v
+        wandb.log(wandb_data, step=self.global_step)
+        self.status = EvaluationStepStatus.LOGGED
+        logger.info(f"Logging eval result {wandb_data} to step {self.global_step}")
+        return True
+
+    def check(self):
+        assert self.process is not None
+        result = self.process.poll()
+        if not result is None:
+            logger.info(
+                f"Evaluation of checkpoint (globalstep{self.global_step}) is done, returncode={self.process.returncode}, "
+                f"time passed {time.perf_counter() - self.start_time:.3f} s."
+            )
+            if self.process.returncode == 0:
+                self.status = EvaluationStepStatus.DONE
+            else:
+                self.status = EvaluationStepStatus.FAILED
+
+
 class AutomaticEvaluator:
     def __init__(
@@ -20,34 +155,31 @@ class AutomaticEvaluator:
         config: config_pkg.AutomaticEvaluator,
         wandb_config: config_pkg.WandBConfig,
     ):
-        self.__running_processes: Dict[int, subprocess.Popen] = {}
-        self.__start_time = {}
-        self.__done_steps = []
-        self.__wandb_log_steps = []
-        self.__pending_ckpts = {}
-        self.__image = config.eval_job_image or cluster.spec.gpu_image
-        self.__data_names = config.data_names
-        self.__max_gen_tokens = config.max_gen_tokens
+        self.__eval_steps: Dict[int, EvaluationStep] = {}
         self.__max_concurrent_jobs = config.max_concurrent_jobs
-        self.__prompt_type = config.prompt_type
         self.__wandb_config = wandb_config
         self.__config = config
-        self.__wandb_inited = False
+        self.__wandb_initialized = False

-        # Check evaluated checkpoints by logs
-        former_output_dir = os.path.join(
+        # Check evaluated checkpoints by logs in recover
+        # NOTE: All previous evaluation steps with output will be marked
+        # as logged, even if it is not really logged in wandb.
+        # This is because we do not know the status of evaluation jobs
+        # submitted before recover.
+        # Resubmiting or waiting for these jobs will probably result in
+        # unexpected behaviors.
+        output_parent = os.path.join(
             constants.LOG_ROOT,
             constants.experiment_name(),
             constants.trial_name(),
             "eval_output",
         )
-        if os.path.exists(former_output_dir):
-            for log_dir in os.listdir(former_output_dir):
-                match = re.match(r"globalstep(\d+)", log_dir)
-                if not match:
-                    continue
-                global_step = int(match.groups()[0])
-                self.__done_steps.append(global_step)
+        if os.path.exists(output_parent):
+            for output_dir in os.listdir(output_parent):
+                output_dir = os.path.join(output_parent, output_dir)
+                eval_step = EvaluationStep.from_output_dir(output_dir)
+                if eval_step:
+                    self.__eval_steps[eval_step.global_step] = eval_step

         logger.info(
             f"Initializing AutomaticEvaluator: \n"
@@ -55,10 +187,15 @@ class AutomaticEvaluator:
             f"data_names: {config.data_names}\n"
             f"max_gen_tokens: {config.max_gen_tokens}\n"
             f"max_concurrent_jobs: {config.max_concurrent_jobs}\n"
-            f"Existing eval outputs for global steps: {self.__done_steps}"
+            f"Existing eval outputs for global steps: "
+            f"{list(self.__eval_steps.keys())}"
         )
-        if self.__config.initial_checkpoint_path and 0 not in self.__done_steps:
-            self.__pending_ckpts[0] = self.__config.initial_checkpoint_path
+        if self.__config.initial_checkpoint_path and 0 not in self.__eval_steps:
+            self.__eval_steps[0] = EvaluationStep(
+                global_step=0,
+                status=EvaluationStepStatus.PENDING,
+                ckpt_dir=self.__config.initial_checkpoint_path,
+            )

         if not cluster.spec.cluster_type == "slurm":
             raise NotImplementedError(
@@ -66,7 +203,10 @@ class AutomaticEvaluator:
         )

     def __lazy_wandb_init(self):
-        # Initializing wandb for evaluator
+        # Initializing wandb for evaluator.
+        # Here we use lazy init because if this wandb instance is launched
+        # with wandb instance on master worker without a time interval,
+        # one of them will fail.
         wandb.login()
         wandb.init(
             mode=self.__wandb_config.mode,
@@ -88,159 +228,80 @@ class AutomaticEvaluator:
             settings=wandb.Settings(start_method="fork"),
         )

-    def __check_new_ckpts(self):
-        save_path = os.path.join(
-            constants.MODEL_SAVE_ROOT,
-            constants.experiment_name(),
-            constants.trial_name(),
-            "actor",
-        )
-        if not os.path.exists(save_path):
-            return
-        for ckpt_dir in os.listdir(save_path):
-            match = re.match(r"epoch(\d+)epochstep(\d+)globalstep(\d+)", ckpt_dir)
-            if not match:
-                continue
-            _, _, global_step = map(int, match.groups())
-            if not global_step in (
-                list(self.__running_processes.keys())
-                + list(self.__pending_ckpts.keys())
-                + self.__done_steps
-            ):
-                abs_ckpt_dir = os.path.join(save_path, ckpt_dir)
-                logger.info(
-                    f"Found new checkpoint (globalstep{global_step}) at {abs_ckpt_dir}"
-                )
-                self.__pending_ckpts[global_step] = os.path.join(
-                    save_path, abs_ckpt_dir
-                )
-
-    def __check_and_maybe_submit_jobs(self):
-        for global_step, process in self.__running_processes.items():
-            result = process.poll()
-            if not result is None:
-                self.__done_steps.append(global_step)
-                start_time = self.__start_time[global_step]
-                logger.info(
-                    f"Evaluation of checkpoint (globalstep{global_step}) is done, returncode={process.returncode}, "
-                    f"time passed {time.perf_counter() - start_time:.3f} s."
-                )
-        for done in self.__done_steps:
-            if done in self.__running_processes:
-                self.__running_processes.pop(done)
-                self.__start_time.pop(done)
-
-        submitted = []
-        # Jobs should be submitted by the order of global steps
-        ordered_steps = sorted(self.__pending_ckpts.keys())
-        for global_step in ordered_steps:
-            ckpt_path = self.__pending_ckpts[global_step]
-            if len(self.__running_processes) >= self.__max_concurrent_jobs:
-                return
-            self.__submit_one(global_step, ckpt_path)
-            submitted.append(global_step)
-        for global_step in submitted:
-            self.__pending_ckpts.pop(global_step)
-
-    def __submit_one(self, global_step, ckpt_path):
-        output_path = self.eval_output_path(global_step)
-        os.makedirs(output_path, exist_ok=True)
-        log_file = open(os.path.join(output_path, "output.log"), "w")
-        if cluster.spec.cluster_type == "slurm":
-            cmd = self.slurm_eval_cmd(global_step, ckpt_path)
-        else:
-            raise NotImplementedError(
-                "AutomaticEvaluator does only support slurm cluster."
-            )
-        logger.info(
-            f"Submitting evaluation job of checkpoint at {ckpt_path} (globalstep{global_step}), "
-            f"command: {cmd}"
-        )
-        self.__running_processes[global_step] = subprocess.Popen(
-            cmd,
-            stdout=log_file,
-            stderr=log_file,
-            shell=True,
-        )
-        self.__start_time[global_step] = time.perf_counter()
-
-    def __maybe_parse_and_log_to_wandb(self):
-        # Note that after recover, all previous done steps will be
-        # logged again in case some data points are missing.
-        # If the data point is already logged, wandb will raise
-        # a warning.
-        to_log = list(
-            filter(
-                lambda x: x not in self.__wandb_log_steps,
-                (
-                    self.__done_steps
-                    + list(self.__running_processes.keys())
-                    + list(self.__pending_ckpts.keys())
-                ),
-            )
-        )
-        while to_log:
-            # The wandb should always log the minimal global step
-            # whose eval job has been submitted but not logged.
-            # If this minimal step is not logged, other steps should wait.
-            global_step = min(to_log)
-            result_path = os.path.join(
-                self.eval_output_path(global_step),
-                f"math_eval_{self.__max_gen_tokens}",
-                f"aggregate_parallel_{self.__prompt_type}.json",
-            )
-            if not os.path.exists(result_path):
-                break
-
-            if not self.__wandb_inited:
-                self.__lazy_wandb_init()
-                self.__wandb_inited = True
-
-            try:
-                with open(result_path, "r") as fp:
-                    data = json.load(fp)
-            except json.JSONDecodeError:
-                logger.warning(f"JSON decoding for eval result in {result_path} failed")
-                continue
-            except FileNotFoundError:
-                logger.warning(
-                    f"{result_path} not found, but the eval job is done. "
-                    "Maybe the eval job abnormally exited and did not output the result."
-                )
-                continue
-            wandb_data = {}
-            for data_name, d in data.items():
-                for k, v in d.items():
-                    wandb_data[f"{data_name}_{k}"] = v
-            wandb.log(wandb_data, step=global_step)
-            logger.info(f"Logging eval result {wandb_data} to step {global_step}")
-            self.__wandb_log_steps.append(global_step)
-            to_log.remove(global_step)
-
-    def step(self):
-        self.__check_new_ckpts()
-        self.__check_and_maybe_submit_jobs()
-        self.__maybe_parse_and_log_to_wandb()
-
-    def slurm_eval_cmd(self, global_step, ckpt_path):
-        slurm_job_name = f"{constants.experiment_name()}_{constants.trial_name()}:eval_globalstep{global_step}"
-        cmd = (
-            f"srun --mpi=pmi2 -J {slurm_job_name} --ntasks=1 --cpus-per-task=128 --gres=gpu:8 --mem-per-cpu=12G "
-            f"singularity exec --nv --pid --writable-tmpfs --bind /storage:/storage "
-            f"{self.__image} "
-            f"bash ./evaluation/sh/install_deps_and_eval.sh {ckpt_path} {self.eval_output_path(global_step)} "
-            f"{self.__data_names} {self.__max_gen_tokens} {self.__prompt_type}"
-        )
-        return cmd
-
-    def eval_output_path(self, global_step):
-        return os.path.join(
-            constants.LOG_ROOT,
-            constants.experiment_name(),
-            constants.trial_name(),
-            "eval_output",
-            f"globalstep{global_step}",
-        )
+    def step(self):
+        # Check whether a new evaluation step should be created
+        ckpt_parent = os.path.join(
+            constants.MODEL_SAVE_ROOT,
+            constants.experiment_name(),
+            constants.trial_name(),
+            "actor",
+        )
+        if os.path.exists(ckpt_parent):
+            for ckpt_dir in os.listdir(ckpt_parent):
+                ckpt_dir = os.path.join(ckpt_parent, ckpt_dir)
+                eval_step = EvaluationStep.from_ckpt_dir(ckpt_dir)
+                if eval_step is None:
+                    continue
+                if eval_step.global_step in self.__eval_steps:
+                    continue
+                self.__eval_steps[eval_step.global_step] = eval_step
+                logger.info(
+                    f"Found new checkpoint (globalstep{eval_step.global_step}) "
+                    f"at {ckpt_dir}"
+                )
+
+        # Submit pending evaluation step
+        if self.__running_jobs < self.__max_concurrent_jobs:
+            # Submit in global_step order
+            pending_steps = list(
+                filter(
+                    lambda x: self.__eval_steps[x].status
+                    == EvaluationStepStatus.PENDING,
+                    self.__eval_steps.keys(),
+                )
+            )
+            if pending_steps:
+                min_pending = min(pending_steps)
+                self.__eval_steps[min_pending].submit(self.__config)
+
+        # Check if any eval job is done or failed
+        running_steps = filter(
+            lambda x: self.__eval_steps[x].status == EvaluationStepStatus.RUNNING,
+            self.__eval_steps.keys(),
+        )
+        for global_step in running_steps:
+            self.__eval_steps[global_step].check()
+
+        # Check whether the **minimal global step**, not logged or failed, is done,
+        # and log this step to wandb if done.
+        # NOTE: LOGGED and FAILED steps have identical behaviors now.
+        # But in future versions that supports multi-node eval they could be different.
+        log_steps = list(
+            filter(
+                lambda x: self.__eval_steps[x].status
+                not in [
+                    EvaluationStepStatus.LOGGED,
+                    EvaluationStepStatus.FAILED,
+                ],
+                self.__eval_steps.keys(),
+            )
+        )
+        if log_steps:
+            log_step = min(log_steps)
+            if self.__eval_steps[log_step].status == EvaluationStepStatus.DONE:
+                if not self.__wandb_initialized:
+                    self.__lazy_wandb_init()
+                    self.__wandb_initialized = True
+                self.__eval_steps[log_step].log(self.__config)
+
+    @property
+    def __running_jobs(self):
+        return len(
+            list(
+                filter(
+                    lambda x: self.__eval_steps[x].status
+                    == EvaluationStepStatus.RUNNING,
+                    self.__eval_steps.keys(),
+                )
+            )
+        )
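
The refactor collapses the old scan/submit/log pipeline into a single non-blocking step() plus per-checkpoint EvaluationStep state, so the evaluator can be polled from the main process (see the "move evaluator to main process" commit). A minimal sketch of such a polling loop is below; it assumes only the step() interface shown in this diff, and should_stop is a hypothetical callable standing in for however the trainer signals completion.

import time

def drive_evaluator(evaluator, should_stop, poll_interval_s: float = 30.0):
    # Poll an AutomaticEvaluator-like object from the main process.
    # step() scans for new checkpoints, submits pending eval jobs (bounded by
    # max_concurrent_jobs), polls running slurm jobs, and logs finished results.
    while not should_stop():
        evaluator.step()
        time.sleep(poll_interval_s)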

View File

@@ -136,7 +136,6 @@ class SlurmSchedulerClient(SchedulerClient):
         wrap_cmd += f"--bind {launch_info.container_mounts} "
         wrap_cmd += f"{launch_info.container_image} "
         wrap_cmd += "bash -c '{}'".format(cmd)
-        wrap_cmd = wrap_cmd.format(cmd)
         launch_info.multiprog_content = f"0-{launch_info.n_jobsteps - 1} {wrap_cmd}\n"
         return launch_info
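
The removed line formatted the wrapped command a second time. After the first .format(cmd) there is no placeholder left, so the second call is at best a no-op, and it raises as soon as cmd itself contains literal braces (JSON arguments, shell parameter expansion, and so on). A standalone reproduction of the failure mode, using the same pattern as the diff but none of the real scheduler code:

cmd = "python train.py --config '{\"lr\": 0.001}'"

wrap_cmd = "srun ... bash -c '{}'".format(cmd)  # first format: fills the single placeholder
try:
    wrap_cmd = wrap_cmd.format(cmd)             # second format: re-parses the braces inside cmd
except (KeyError, IndexError) as exc:
    print("double .format() fails:", repr(exc))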

View File

@@ -289,9 +289,10 @@ class MasterWorker(worker_base.Worker):
             mode=self.wandb_config.mode,
             entity=self.wandb_config.entity,
             project=self.wandb_config.project or constants.experiment_name(),
-            name=self.wandb_config.name or constants.trial_name(),
+            name=self.wandb_config.name or f"{constants.trial_name()}_train",
             job_type=self.wandb_config.job_type,
-            group=self.wandb_config.group,
+            group=self.wandb_config.group
+            or f"{constants.experiment_name()}_{constants.trial_name()}",
             notes=self.wandb_config.notes,
             tags=self.wandb_config.tags,
             config=self.wandb_config.config,
@@ -299,6 +300,7 @@ class MasterWorker(worker_base.Worker):
                 constants.LOG_ROOT, constants.experiment_name(), constants.trial_name()
             ),
             force=True,
+            id=f"{constants.experiment_name()}_{constants.trial_name()}_train",
             resume="allow",
             settings=wandb.Settings(start_method="fork"),
         )
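
With these changes the training run gets a distinct name and a deterministic id, and the evaluator (see the evaluator diff above) logs to its own run, so the two no longer collide while a shared group keeps them side by side in the wandb UI. A minimal sketch of that naming convention, assuming a "_train"/"_eval" role suffix; the evaluator's own wandb.init arguments are mostly outside this diff, so the eval-side naming is an assumption.

import wandb

def init_run(experiment_name: str, trial_name: str, role: str):
    # role is "train" for the master worker, "eval" for the evaluator (assumed).
    # A deterministic id plus resume="allow" lets a recovered worker reattach
    # to its previous wandb run instead of starting a new one.
    return wandb.init(
        name=f"{trial_name}_{role}",
        group=f"{experiment_name}_{trial_name}",
        id=f"{experiment_name}_{trial_name}_{role}",
        job_type=role,
        resume="allow",
        settings=wandb.Settings(start_method="fork"),
    )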