mirror of https://github.com/inclusionAI/AReaL
PullRequest: 20 Fix bugs in auto evaluation

Merge branch mzy/refactor-eval of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/20?tab=diff

Signed-off-by: 博惟 <bowei.fw@antgroup.com>

* test
* move evaluator to main process
* clean up code
* add docstring
* separate wandb groups
* handle eval errors
* add check for failed eval
* refactor evaluator
* refactor evaluator, fix master worker wandb login
This commit is contained in:

parent 26a48be73e
commit 3d8be914af
@@ -5,7 +5,6 @@ import subprocess
 from glob import glob
 
 import numpy as np
-import wandb
 from rm_maj_eval import group_pred
 from tqdm import tqdm
 from transformers import AutoTokenizer
@@ -110,7 +109,6 @@ def get_metrics(fname_pattern, tokenizer, is_greedy):
             "num_questions": len(lengths),
         }
     else:
-
         return {
             "sample_length": np.mean(lengths),
             "sample_pass@1": pass_at_k(results.values(), 1),
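The hunk above calls `pass_at_k(results.values(), 1)`. For orientation, here is a minimal sketch of the standard unbiased pass@k estimator with the same call shape; the repo's actual `pass_at_k` may differ in details such as how per-question results are encoded, so treat the names and the 0/1 encoding as assumptions.

```python
import numpy as np

# Hypothetical sketch of an unbiased pass@k estimator; not the repo's code.
def pass_at_k(all_results, k):
    estimates = []
    for results in all_results:  # one list of 0/1 correctness flags per question
        n, c = len(results), sum(results)
        if n - c < k:
            estimates.append(1.0)  # every size-k draw contains a correct sample
        else:
            # 1 - C(n-c, k) / C(n, k), computed as a numerically stable product
            estimates.append(1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)))
    return float(np.mean(estimates))

# pass@1 of a question answered correctly in 1 of 4 samples is 0.25:
assert abs(pass_at_k({"q1": [1, 0, 0, 0]}.values(), 1) - 0.25) < 1e-9
```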
@@ -1,9 +1,12 @@
 import dataclasses
+import enum
+import json
 import os
+import pathlib
 import re
 import subprocess
 import time
-from typing import Dict
+from typing import Dict, Optional
 
 import wandb
 
@@ -13,6 +16,138 @@ from realhf.base import cluster, constants, logging
 logger = logging.getLogger("AutomaticEvaluator", "colored")
 
 
+class EvaluationStepStatus(enum.Enum):
+    PENDING = 0
+    RUNNING = 1
+    FAILED = 2
+    DONE = 3
+    LOGGED = 4
+
+
+@dataclasses.dataclass
+class EvaluationStep:
+    global_step: int
+    status: EvaluationStepStatus
+    start_time: Optional[float] = None
+    ckpt_dir: Optional[str] = None
+    process: Optional[subprocess.Popen] = None
+
+    @staticmethod
+    def from_ckpt_dir(ckpt_dir):
+        # NOTE: ckpt_dir should be an absolute path.
+        if pathlib.Path(ckpt_dir).is_symlink():
+            return None
+        _dir = os.path.basename(ckpt_dir)
+        match = re.match(r"epoch(\d+)epochstep(\d+)globalstep(\d+)", _dir)
+        if not match:
+            return None
+        _, _, global_step = map(int, match.groups())
+        return EvaluationStep(
+            global_step=global_step,
+            status=EvaluationStepStatus.PENDING,
+            ckpt_dir=ckpt_dir,
+        )
+
+    @staticmethod
+    def from_output_dir(output_dir):
+        # NOTE: output_dir should be an absolute path.
+        # Should only be called in recover.
+        _dir = os.path.basename(output_dir)
+        match = re.match(r"globalstep(\d+)", _dir)
+        if not match:
+            return None
+        global_step = int(match.groups()[0])
+        return EvaluationStep(
+            global_step=global_step, status=EvaluationStepStatus.LOGGED
+        )
+
+    @property
+    def output_dir(self):
+        return os.path.join(
+            constants.LOG_ROOT,
+            constants.experiment_name(),
+            constants.trial_name(),
+            "eval_output",
+            f"globalstep{self.global_step}",
+        )
+
+    def slurm_eval_cmd(self, config: config_pkg.AutomaticEvaluator):
+        slurm_job_name = f"{constants.experiment_name()}_{constants.trial_name()}:eval_globalstep{self.global_step}"
+        cmd = (
+            f"srun --mpi=pmi2 -J {slurm_job_name} --ntasks=1 --cpus-per-task=128 --gres=gpu:8 --mem-per-cpu=12G "
+            f"singularity exec --no-home --nv --pid --writable-tmpfs --bind /storage:/storage "
+            f"{config.eval_job_image or cluster.spec.gpu_image} "
+            f"bash ./evaluation/sh/install_deps_and_eval.sh {self.ckpt_dir} {self.output_dir} "
+            f"{config.data_names} {config.max_gen_tokens} {config.prompt_type}"
+        )
+        return cmd
+
+    def submit(self, config: config_pkg.AutomaticEvaluator):
+        os.makedirs(self.output_dir, exist_ok=True)
+        log_file = open(os.path.join(self.output_dir, "output.log"), "w")
+        if cluster.spec.cluster_type == "slurm":
+            cmd = self.slurm_eval_cmd(config)
+        else:
+            raise NotImplementedError(
+                "AutomaticEvaluator only supports slurm clusters."
+            )
+
+        logger.info(
+            f"Submitting evaluation job of checkpoint at {self.ckpt_dir} (globalstep{self.global_step}), "
+            f"command: {cmd}"
+        )
+        self.process = subprocess.Popen(
+            cmd,
+            stdout=log_file,
+            stderr=log_file,
+            shell=True,
+        )
+        self.start_time = time.perf_counter()
+        self.status = EvaluationStepStatus.RUNNING
+
+    def log(self, config: config_pkg.AutomaticEvaluator) -> bool:
+        result_path = os.path.join(
+            self.output_dir,
+            f"math_eval_{config.max_gen_tokens}",
+            f"aggregate_parallel_{config.prompt_type}.json",
+        )
+        # NOTE: If the result JSON is missing or fails to decode,
+        # the evaluation step is marked as failed.
+        try:
+            with open(result_path, "r") as fp:
+                data = json.load(fp)
+        except json.JSONDecodeError:
+            logger.warning(f"JSON file {result_path} decoding failed.")
+            self.status = EvaluationStepStatus.FAILED
+            return False
+        except FileNotFoundError:
+            logger.warning(f"JSON file {result_path} does not exist.")
+            self.status = EvaluationStepStatus.FAILED
+            return False
+
+        wandb_data = {}
+        for data_name, d in data.items():
+            for k, v in d.items():
+                wandb_data[f"{data_name}_{k}"] = v
+        wandb.log(wandb_data, step=self.global_step)
+        self.status = EvaluationStepStatus.LOGGED
+        logger.info(f"Logging eval result {wandb_data} to step {self.global_step}")
+        return True
+
+    def check(self):
+        assert self.process is not None
+        result = self.process.poll()
+        if result is not None:
+            logger.info(
+                f"Evaluation of checkpoint (globalstep{self.global_step}) is done, returncode={self.process.returncode}, "
+                f"time passed {time.perf_counter() - self.start_time:.3f} s."
+            )
+            if self.process.returncode == 0:
+                self.status = EvaluationStepStatus.DONE
+            else:
+                self.status = EvaluationStepStatus.FAILED
+
+
 class AutomaticEvaluator:
 
     def __init__(
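A quick usage sketch of the two constructors above; the checkpoint and output paths are made up for illustration, and only the directory basenames matter to the regexes.

```python
# Hypothetical paths, for illustration only.
step = EvaluationStep.from_ckpt_dir(
    "/storage/ckpt/actor/epoch3epochstep5globalstep40"
)
assert step.global_step == 40
assert step.status == EvaluationStepStatus.PENDING  # eligible for submission

# During recover, existing outputs are re-wrapped as already-logged steps:
prev = EvaluationStep.from_output_dir("/storage/logs/eval_output/globalstep20")
assert prev.global_step == 20
assert prev.status == EvaluationStepStatus.LOGGED   # never resubmitted or re-logged

# Symlinks and non-matching directory names are skipped:
assert EvaluationStep.from_ckpt_dir("/storage/ckpt/actor/latest") is None
```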
@@ -20,34 +155,31 @@ class AutomaticEvaluator:
         config: config_pkg.AutomaticEvaluator,
         wandb_config: config_pkg.WandBConfig,
     ):
-        self.__running_processes: Dict[int, subprocess.Popen] = {}
-        self.__start_time = {}
-        self.__done_steps = []
-        self.__wandb_log_steps = []
-        self.__pending_ckpts = {}
-        self.__image = config.eval_job_image or cluster.spec.gpu_image
-        self.__data_names = config.data_names
-        self.__max_gen_tokens = config.max_gen_tokens
+        self.__eval_steps: Dict[int, EvaluationStep] = {}
         self.__max_concurrent_jobs = config.max_concurrent_jobs
         self.__prompt_type = config.prompt_type
         self.__wandb_config = wandb_config
         self.__config = config
-        self.__wandb_inited = False
+        self.__wandb_initialized = False
 
-        # Check evaluated checkpoints by logs
-        former_output_dir = os.path.join(
+        # Check already-evaluated checkpoints by their logs in recover.
+        # NOTE: All previous evaluation steps with output are marked as
+        # logged, even if they were never actually logged to wandb,
+        # because we do not know the status of evaluation jobs submitted
+        # before the recover. Resubmitting or waiting for these jobs
+        # would probably result in unexpected behavior.
+        output_parent = os.path.join(
             constants.LOG_ROOT,
             constants.experiment_name(),
             constants.trial_name(),
             "eval_output",
         )
-        if os.path.exists(former_output_dir):
-            for log_dir in os.listdir(former_output_dir):
-                match = re.match(r"globalstep(\d+)", log_dir)
-                if not match:
-                    continue
-                global_step = int(match.groups()[0])
-                self.__done_steps.append(global_step)
+        if os.path.exists(output_parent):
+            for output_dir in os.listdir(output_parent):
+                output_dir = os.path.join(output_parent, output_dir)
+                eval_step = EvaluationStep.from_output_dir(output_dir)
+                if eval_step:
+                    self.__eval_steps[eval_step.global_step] = eval_step
 
         logger.info(
             f"Initializing AutomaticEvaluator: \n"
@@ -55,10 +187,15 @@ class AutomaticEvaluator:
             f"data_names: {config.data_names}\n"
             f"max_gen_tokens: {config.max_gen_tokens}\n"
             f"max_concurrent_jobs: {config.max_concurrent_jobs}\n"
-            f"Existing eval outputs for global steps: {self.__done_steps}"
+            f"Existing eval outputs for global steps: "
+            f"{list(self.__eval_steps.keys())}"
         )
-        if self.__config.initial_checkpoint_path and 0 not in self.__done_steps:
-            self.__pending_ckpts[0] = self.__config.initial_checkpoint_path
+        if self.__config.initial_checkpoint_path and 0 not in self.__eval_steps:
+            self.__eval_steps[0] = EvaluationStep(
+                global_step=0,
+                status=EvaluationStepStatus.PENDING,
+                ckpt_dir=self.__config.initial_checkpoint_path,
+            )
 
         if not cluster.spec.cluster_type == "slurm":
             raise NotImplementedError(
@@ -66,7 +203,10 @@ class AutomaticEvaluator:
             )
 
     def __lazy_wandb_init(self):
-        # Initializing wandb for evaluator
+        # Initialize wandb for the evaluator.
+        # We initialize lazily because if this wandb instance is launched
+        # too soon after the wandb instance on the master worker,
+        # one of them will fail.
         wandb.login()
         wandb.init(
             mode=self.__wandb_config.mode,
@@ -88,159 +228,80 @@ class AutomaticEvaluator:
             settings=wandb.Settings(start_method="fork"),
         )
 
-    def __check_new_ckpts(self):
-        save_path = os.path.join(
-            constants.MODEL_SAVE_ROOT,
-            constants.experiment_name(),
-            constants.trial_name(),
-            "actor",
-        )
-        if not os.path.exists(save_path):
-            return
-        for ckpt_dir in os.listdir(save_path):
-            match = re.match(r"epoch(\d+)epochstep(\d+)globalstep(\d+)", ckpt_dir)
-            if not match:
-                continue
-            _, _, global_step = map(int, match.groups())
-            if not global_step in (
-                list(self.__running_processes.keys())
-                + list(self.__pending_ckpts.keys())
-                + self.__done_steps
-            ):
-                abs_ckpt_dir = os.path.join(save_path, ckpt_dir)
-                logger.info(
-                    f"Found new checkpoint (globalstep{global_step}) at {abs_ckpt_dir}"
-                )
-                self.__pending_ckpts[global_step] = os.path.join(
-                    save_path, abs_ckpt_dir
-                )
-
-    def __check_and_maybe_submit_jobs(self):
-        for global_step, process in self.__running_processes.items():
-            result = process.poll()
-            if not result is None:
-                self.__done_steps.append(global_step)
-                start_time = self.__start_time[global_step]
-                logger.info(
-                    f"Evaluation of checkpoint (globalstep{global_step}) is done, returncode={process.returncode}, "
-                    f"time passed {time.perf_counter() - start_time:.3f} s."
-                )
-
-        for done in self.__done_steps:
-            if done in self.__running_processes:
-                self.__running_processes.pop(done)
-                self.__start_time.pop(done)
-
-        submitted = []
-        # Jobs should be submitted in the order of global steps
-        ordered_steps = sorted(self.__pending_ckpts.keys())
-        for global_step in ordered_steps:
-            ckpt_path = self.__pending_ckpts[global_step]
-            if len(self.__running_processes) >= self.__max_concurrent_jobs:
-                return
-            self.__submit_one(global_step, ckpt_path)
-            submitted.append(global_step)
-
-        for global_step in submitted:
-            self.__pending_ckpts.pop(global_step)
-
-    def __submit_one(self, global_step, ckpt_path):
-        output_path = self.eval_output_path(global_step)
-        os.makedirs(output_path, exist_ok=True)
-        log_file = open(os.path.join(output_path, "output.log"), "w")
-        if cluster.spec.cluster_type == "slurm":
-            cmd = self.slurm_eval_cmd(global_step, ckpt_path)
-        else:
-            raise NotImplementedError(
-                "AutomaticEvaluator does only support slurm cluster."
-            )
-
-        logger.info(
-            f"Submitting evaluation job of checkpoint at {ckpt_path} (globalstep{global_step}), "
-            f"command: {cmd}"
-        )
-        self.__running_processes[global_step] = subprocess.Popen(
-            cmd,
-            stdout=log_file,
-            stderr=log_file,
-            shell=True,
-        )
-        self.__start_time[global_step] = time.perf_counter()
-
-    def __maybe_parse_and_log_to_wandb(self):
-        # Note that after recover, all previous done steps will be
-        # logged again in case some data points are missing.
-        # If the data point is already logged, wandb will raise
-        # a warning.
-        to_log = list(
-            filter(
-                lambda x: x not in self.__wandb_log_steps,
-                (
-                    self.__done_steps
-                    + list(self.__running_processes.keys())
-                    + list(self.__pending_ckpts.keys())
-                ),
-            )
-        )
-        while to_log:
-            # wandb should always log the minimal global step
-            # whose eval job has been submitted but not logged.
-            # If this minimal step is not logged, other steps should wait.
-            global_step = min(to_log)
-            result_path = os.path.join(
-                self.eval_output_path(global_step),
-                f"math_eval_{self.__max_gen_tokens}",
-                f"aggregate_parallel_{self.__prompt_type}.json",
-            )
-            if not os.path.exists(result_path):
-                break
-
-            if not self.__wandb_inited:
-                self.__lazy_wandb_init()
-                self.__wandb_inited = True
-
-            try:
-                with open(result_path, "r") as fp:
-                    data = json.load(fp)
-            except json.JSONDecodeError:
-                logger.warning(f"JSON decoding for eval result in {result_path} failed")
-                continue
-            except FileNotFoundError:
-                logger.warning(
-                    f"{result_path} not found, but the eval job is done. "
-                    "Maybe the eval job abnormally exited and did not output the result."
-                )
-                continue
-            wandb_data = {}
-            for data_name, d in data.items():
-                for k, v in d.items():
-                    wandb_data[f"{data_name}_{k}"] = v
-            wandb.log(wandb_data, step=global_step)
-            logger.info(f"Logging eval result {wandb_data} to step {global_step}")
-            self.__wandb_log_steps.append(global_step)
-            to_log.remove(global_step)
-
-    def step(self):
-        self.__check_new_ckpts()
-        self.__check_and_maybe_submit_jobs()
-        self.__maybe_parse_and_log_to_wandb()
-
-    def slurm_eval_cmd(self, global_step, ckpt_path):
-        slurm_job_name = f"{constants.experiment_name()}_{constants.trial_name()}:eval_globalstep{global_step}"
-        cmd = (
-            f"srun --mpi=pmi2 -J {slurm_job_name} --ntasks=1 --cpus-per-task=128 --gres=gpu:8 --mem-per-cpu=12G "
-            f"singularity exec --nv --pid --writable-tmpfs --bind /storage:/storage "
-            f"{self.__image} "
-            f"bash ./evaluation/sh/install_deps_and_eval.sh {ckpt_path} {self.eval_output_path(global_step)} "
-            f"{self.__data_names} {self.__max_gen_tokens} {self.__prompt_type}"
-        )
-        return cmd
-
-    def eval_output_path(self, global_step):
-        return os.path.join(
-            constants.LOG_ROOT,
-            constants.experiment_name(),
-            constants.trial_name(),
-            "eval_output",
-            f"globalstep{global_step}",
-        )
+    def step(self):
+        # Check whether a new evaluation step should be created.
+        ckpt_parent = os.path.join(
+            constants.MODEL_SAVE_ROOT,
+            constants.experiment_name(),
+            constants.trial_name(),
+            "actor",
+        )
+        if os.path.exists(ckpt_parent):
+            for ckpt_dir in os.listdir(ckpt_parent):
+                ckpt_dir = os.path.join(ckpt_parent, ckpt_dir)
+                eval_step = EvaluationStep.from_ckpt_dir(ckpt_dir)
+                if eval_step is None:
+                    continue
+                if eval_step.global_step in self.__eval_steps:
+                    continue
+                self.__eval_steps[eval_step.global_step] = eval_step
+                logger.info(
+                    f"Found new checkpoint (globalstep{eval_step.global_step}) "
+                    f"at {ckpt_dir}"
+                )
+
+        # Submit the pending evaluation step with the smallest global step.
+        if self.__running_jobs < self.__max_concurrent_jobs:
+            pending_steps = list(
+                filter(
+                    lambda x: self.__eval_steps[x].status
+                    == EvaluationStepStatus.PENDING,
+                    self.__eval_steps.keys(),
+                )
+            )
+            if pending_steps:
+                min_pending = min(pending_steps)
+                self.__eval_steps[min_pending].submit(self.__config)
+
+        # Check whether any eval job is done or failed.
+        running_steps = filter(
+            lambda x: self.__eval_steps[x].status == EvaluationStepStatus.RUNNING,
+            self.__eval_steps.keys(),
+        )
+        for global_step in running_steps:
+            self.__eval_steps[global_step].check()
+
+        # Check whether the **minimal global step** that is neither logged
+        # nor failed is done, and log this step to wandb if so.
+        # NOTE: LOGGED and FAILED steps have identical behavior for now,
+        # but future versions that support multi-node eval could treat
+        # them differently.
+        log_steps = list(
+            filter(
+                lambda x: self.__eval_steps[x].status
+                not in [
+                    EvaluationStepStatus.LOGGED,
+                    EvaluationStepStatus.FAILED,
+                ],
+                self.__eval_steps.keys(),
+            )
+        )
+        if log_steps:
+            log_step = min(log_steps)
+            if self.__eval_steps[log_step].status == EvaluationStepStatus.DONE:
+                if not self.__wandb_initialized:
+                    self.__lazy_wandb_init()
+                    self.__wandb_initialized = True
+                self.__eval_steps[log_step].log(self.__config)
+
+    @property
+    def __running_jobs(self):
+        return len(
+            list(
+                filter(
+                    lambda x: self.__eval_steps[x].status
+                    == EvaluationStepStatus.RUNNING,
+                    self.__eval_steps.keys(),
+                )
+            )
+        )
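The refactored step() collapses the three old private methods into one pass over an explicit state machine. A standalone toy illustration of its two ordering rules (all names and values here are hypothetical, not the repo's code):

```python
import enum

class Status(enum.Enum):
    PENDING = 0
    RUNNING = 1
    FAILED = 2
    DONE = 3
    LOGGED = 4

steps = {100: Status.DONE, 200: Status.PENDING, 300: Status.PENDING}
running_jobs, max_concurrent_jobs = 0, 2

# Rule 1: submit PENDING steps in global_step order, respecting the cap.
pending = [s for s, st in steps.items() if st is Status.PENDING]
if running_jobs < max_concurrent_jobs and pending:
    steps[min(pending)] = Status.RUNNING  # 200 is submitted before 300

# Rule 2: only the smallest step that is neither LOGGED nor FAILED may be
# logged, and only once it is DONE; later DONE steps wait their turn, so
# the wandb step axis stays monotonic.
candidates = [s for s, st in steps.items()
              if st not in (Status.LOGGED, Status.FAILED)]
if candidates and steps[min(candidates)] is Status.DONE:
    steps[min(candidates)] = Status.LOGGED  # 100 is logged first
```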
@@ -136,7 +136,6 @@ class SlurmSchedulerClient(SchedulerClient):
         wrap_cmd += f"--bind {launch_info.container_mounts} "
         wrap_cmd += f"{launch_info.container_image} "
         wrap_cmd += "bash -c '{}'".format(cmd)
-        wrap_cmd = wrap_cmd.format(cmd)
         launch_info.multiprog_content = f"0-{launch_info.n_jobsteps - 1} {wrap_cmd}\n"
         return launch_info
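The fix above deletes a redundant second .format() call on the already-formatted command. A hypothetical repro of why double-formatting is dangerous: once cmd has been substituted in, any brace inside it (or in earlier fragments) is re-interpreted by the second format(), silently splicing cmd in again or raising KeyError on named braces.

```python
# Hypothetical values, for illustration only.
prefix = "singularity exec image "
cmd = "python -c 'd={}; print(d)'"  # a legal shell command containing {}

wrap_cmd = prefix + "bash -c '{}'".format(cmd)  # intended, single substitution
corrupted = wrap_cmd.format(cmd)                # the deleted, redundant call
print(corrupted != wrap_cmd)                    # True: cmd was spliced in twice
# A named brace like {foo} anywhere in wrap_cmd would instead raise KeyError.
```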
@@ -289,9 +289,10 @@ class MasterWorker(worker_base.Worker):
             mode=self.wandb_config.mode,
             entity=self.wandb_config.entity,
             project=self.wandb_config.project or constants.experiment_name(),
-            name=self.wandb_config.name or constants.trial_name(),
+            name=self.wandb_config.name or f"{constants.trial_name()}_train",
             job_type=self.wandb_config.job_type,
-            group=self.wandb_config.group,
+            group=self.wandb_config.group
+            or f"{constants.experiment_name()}_{constants.trial_name()}",
             notes=self.wandb_config.notes,
             tags=self.wandb_config.tags,
             config=self.wandb_config.config,
@@ -299,6 +300,7 @@ class MasterWorker(worker_base.Worker):
                 constants.LOG_ROOT, constants.experiment_name(), constants.trial_name()
             ),
             force=True,
+            id=f"{constants.experiment_name()}_{constants.trial_name()}_train",
             resume="allow",
             settings=wandb.Settings(start_method="fork"),
         )
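With the _train suffix on the run name and id, plus the shared default group, the master worker's wandb run and the evaluator's run coexist under one group instead of colliding. A minimal sketch of the resulting layout, with placeholder experiment/trial names (the evaluator's exact id suffix is not shown in this diff):

```python
import wandb

# Placeholders for constants.experiment_name() / constants.trial_name().
exp, trial = "my_exp", "my_trial"

train_run = wandb.init(
    name=f"{trial}_train",
    id=f"{exp}_{trial}_train",  # unique id, so recover can resume this run
    group=f"{exp}_{trial}",     # shared group clusters train + eval runs in the UI
    resume="allow",
    settings=wandb.Settings(start_method="fork"),
)
```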