# Mirror of https://github.com/inclusionAI/AReaL
import dataclasses
import enum
import json
import os
import pathlib
import re
import subprocess
import time
from typing import Dict, Optional

import swanlab
import wandb

import realhf.api.core.system_api as config_pkg
from realhf.base import cluster, constants, logging

logger = logging.getLogger("AutomaticEvaluator", "colored")


class EvaluationStepStatus(enum.Enum):
    PENDING = 0
    RUNNING = 1
    FAILED = 2
    DONE = 3
    LOGGED = 4


@dataclasses.dataclass
class EvaluationStep:
    global_step: int
    status: EvaluationStepStatus
    start_time: Optional[float] = None
    ckpt_dir: Optional[str] = None
    process: Optional[subprocess.Popen] = None

    @staticmethod
    def from_ckpt_dir(ckpt_dir):
        # NOTE: ckpt_dir should be an absolute path.
        if pathlib.Path(ckpt_dir).is_symlink():
            return None
        _dir = os.path.basename(ckpt_dir)
        match = re.match(r"epoch(\d+)epochstep(\d+)globalstep(\d+)", _dir)
        if not match:
            return None
        _, _, global_step = map(int, match.groups())
        return EvaluationStep(
            global_step=global_step,
            status=EvaluationStepStatus.PENDING,
            ckpt_dir=ckpt_dir,
        )
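
    # Parsing sketch (the path below is hypothetical), assuming the save
    # layout matched by the regex above:
    #
    #   >>> step = EvaluationStep.from_ckpt_dir(
    #   ...     "/storage/ckpts/my_exp/my_trial/actor/epoch1epochstep5globalstep20"
    #   ... )
    #   >>> step.global_step, step.status
    #   (20, <EvaluationStepStatus.PENDING: 0>)
    #
    # Symlinked directories (e.g. a "latest" alias) return None, so each
    # checkpoint is picked up exactly once.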

    @staticmethod
    def from_output_dir(output_dir):
        # NOTE: output_dir should be an absolute path.
        # Should only be called during recover.
        _dir = os.path.basename(output_dir)
        match = re.match(r"globalstep(\d+)", _dir)
        if not match:
            return None
        global_step = int(match.groups()[0])
        return EvaluationStep(
            global_step=global_step, status=EvaluationStepStatus.LOGGED
        )
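
    # Recover sketch (hypothetical path): an existing output directory such as
    # ".../eval_output/globalstep20" is restored as an EvaluationStep with
    # global_step=20 and status=LOGGED, i.e. it is assumed to have been fully
    # processed before the restart.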

    @property
    def output_dir(self):
        return os.path.join(
            constants.LOG_ROOT,
            constants.experiment_name(),
            constants.trial_name(),
            "eval_output",
            f"globalstep{self.global_step}",
        )

    def slurm_eval_cmd(self, config: config_pkg.AutomaticEvaluator):
        slurm_job_name = f"{constants.experiment_name()}_{constants.trial_name()}:eval_globalstep{self.global_step}"
        cmd = (
            f"srun --mpi=pmi2 -J {slurm_job_name} --ntasks=1 --cpus-per-task=128 --gres=gpu:8 --mem-per-cpu=12G "
            f"singularity exec --no-home --nv --pid --writable-tmpfs --bind /storage:/storage "
            f"{config.eval_job_image or cluster.spec.gpu_image} "
            f"bash ./evaluation/sh/install_deps_and_eval.sh {self.ckpt_dir} {self.output_dir} "
            f"{config.data_names} {config.max_gen_tokens} {config.prompt_type}"
        )
        return cmd
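
    # For illustration, with hypothetical values (experiment "my_exp", trial
    # "my_trial", global step 20), the assembled command looks roughly like:
    #
    #   srun --mpi=pmi2 -J my_exp_my_trial:eval_globalstep20 --ntasks=1 \
    #       --cpus-per-task=128 --gres=gpu:8 --mem-per-cpu=12G \
    #       singularity exec --no-home --nv --pid --writable-tmpfs \
    #       --bind /storage:/storage <image> \
    #       bash ./evaluation/sh/install_deps_and_eval.sh <ckpt_dir> <output_dir> \
    #       <data_names> <max_gen_tokens> <prompt_type>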

    def submit(self, config: config_pkg.AutomaticEvaluator):
        os.makedirs(self.output_dir, exist_ok=True)
        log_file = open(os.path.join(self.output_dir, "output.log"), "w")
        if cluster.spec.cluster_type == "slurm":
            cmd = self.slurm_eval_cmd(config)
        else:
            raise NotImplementedError(
                "AutomaticEvaluator only supports slurm clusters."
            )

        logger.info(
            f"Submitting evaluation job of checkpoint at {self.ckpt_dir} (globalstep{self.global_step}), "
            f"command: {cmd}"
        )
        self.process = subprocess.Popen(
            cmd,
            stdout=log_file,
            stderr=log_file,
            shell=True,
        )
        self.start_time = time.perf_counter()
        self.status = EvaluationStepStatus.RUNNING
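
    # Note: stdout/stderr of the spawned job are redirected to
    # <output_dir>/output.log, e.g. (hypothetically)
    # .../eval_output/globalstep20/output.log.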

    def log(self, config: config_pkg.AutomaticEvaluator) -> bool:
        result_path = os.path.join(
            self.output_dir,
            f"math_eval_{config.max_gen_tokens}",
            f"aggregate_parallel_{config.prompt_type}.json",
        )
        # NOTE: If the result file is missing or cannot be decoded as JSON,
        # the evaluation step is marked as failed.
        try:
            with open(result_path, "r") as fp:
                data = json.load(fp)
        except json.JSONDecodeError:
            logger.warning(f"JSON file {result_path} decoding failed.")
            self.status = EvaluationStepStatus.FAILED
            return False
        except FileNotFoundError:
            logger.warning(f"JSON file {result_path} does not exist.")
            self.status = EvaluationStepStatus.FAILED
            return False

        log_data = {}
        for data_name, d in data.items():
            for k, v in d.items():
                log_data[f"{data_name}_{k}"] = v
        wandb.log(log_data, step=self.global_step)
        swanlab.log(log_data, step=self.global_step)
        self.status = EvaluationStepStatus.LOGGED
        logger.info(f"Logging eval result {log_data} to step {self.global_step}")

        return True
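
    # Flattening sketch with a hypothetical aggregate file: an entry such as
    #
    #   {"math500": {"accuracy": 0.82, "num_samples": 500}}
    #
    # becomes {"math500_accuracy": 0.82, "math500_num_samples": 500}, logged
    # against the checkpoint's global step in both wandb and swanlab.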

    def check(self):
        assert self.process is not None
        result = self.process.poll()
        if result is not None:
            logger.info(
                f"Evaluation of checkpoint (globalstep{self.global_step}) is done, returncode={self.process.returncode}, "
                f"time passed {time.perf_counter() - self.start_time:.3f} s."
            )
            if self.process.returncode == 0:
                self.status = EvaluationStepStatus.DONE
            else:
                self.status = EvaluationStepStatus.FAILED


class AutomaticEvaluator:

    def __init__(
        self,
        config: config_pkg.AutomaticEvaluator,
        wandb_config: config_pkg.WandBConfig,
        swanlab_config: config_pkg.SwanlabConfig,
    ):
        self.__eval_steps: Dict[int, EvaluationStep] = {}
        self.__max_concurrent_jobs = config.max_concurrent_jobs
        self.__wandb_config = wandb_config
        self.__swanlab_config = swanlab_config
        self.__config = config
        self.__wandb_initialized = False
        self.__swanlab_initialized = False
        # Identify already-evaluated checkpoints from logs when recovering.
        # NOTE: All previous evaluation steps with output are marked as
        # LOGGED, even if they were never actually logged to wandb.
        # This is because we do not know the status of evaluation jobs
        # submitted before the recover. Resubmitting or waiting for these
        # jobs would probably result in unexpected behaviors.
        output_parent = os.path.join(
            constants.LOG_ROOT,
            constants.experiment_name(),
            constants.trial_name(),
            "eval_output",
        )
        if os.path.exists(output_parent):
            for output_dir in os.listdir(output_parent):
                output_dir = os.path.join(output_parent, output_dir)
                eval_step = EvaluationStep.from_output_dir(output_dir)
                if eval_step:
                    self.__eval_steps[eval_step.global_step] = eval_step

        logger.info(
            f"Initializing AutomaticEvaluator: \n"
            f"eval_job_image: {config.eval_job_image}\n"
            f"data_names: {config.data_names}\n"
            f"max_gen_tokens: {config.max_gen_tokens}\n"
            f"max_concurrent_jobs: {config.max_concurrent_jobs}\n"
            f"Existing eval outputs for global steps: "
            f"{list(self.__eval_steps.keys())}"
        )
        if self.__config.initial_checkpoint_path and 0 not in self.__eval_steps:
            self.__eval_steps[0] = EvaluationStep(
                global_step=0,
                status=EvaluationStepStatus.PENDING,
                ckpt_dir=self.__config.initial_checkpoint_path,
            )

        if cluster.spec.cluster_type != "slurm":
            raise NotImplementedError(
                "Automatic evaluation currently only supports slurm clusters."
            )
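
    # Recover example (hypothetical): if eval_output already contains
    # globalstep10 and globalstep20 from before a restart, both are restored
    # as LOGGED above and are never re-evaluated; any checkpoint whose global
    # step has no eval_output entry is evaluated as usual.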

    def __lazy_wandb_init(self):
        # Initialize wandb for the evaluator lazily: if this wandb run is
        # launched at nearly the same time as the wandb run on the master
        # worker, one of them will fail.
        wandb.login()
        wandb.init(
            mode=self.__wandb_config.mode,
            entity=self.__wandb_config.entity,
            project=self.__wandb_config.project or constants.experiment_name(),
            name=self.__wandb_config.name or f"{constants.trial_name()}_eval",
            job_type=self.__wandb_config.job_type,
            group=self.__wandb_config.group
            or f"{constants.experiment_name()}_{constants.trial_name()}",
            notes=self.__wandb_config.notes,
            tags=self.__wandb_config.tags,
            config=self.__wandb_config.config,
            dir=os.path.join(
                constants.LOG_ROOT, constants.experiment_name(), constants.trial_name()
            ),
            force=True,
            id=f"{constants.experiment_name()}_{constants.trial_name()}_eval",
            resume="allow",
            settings=wandb.Settings(start_method="fork"),
        )

    def __lazy_swanlab_init(self):
        if self.__swanlab_config.api_key:
            swanlab.login(self.__swanlab_config.api_key)
        if self.__swanlab_config.config is None:
            import yaml

            with open(
                os.path.join(
                    constants.LOG_ROOT,
                    constants.experiment_name(),
                    constants.trial_name(),
                    "config.yaml",
                ),
                "r",
            ) as f:
                __config = yaml.safe_load(f)
        else:
            __config = self.__swanlab_config.config
        __config["FRAMEWORK"] = "AReaL"
        swanlab.init(
            project=self.__swanlab_config.project or constants.experiment_name(),
            experiment_name=self.__swanlab_config.name
            or f"{constants.trial_name()}_eval",
            config=__config,
            logdir=self.__swanlab_config.logdir
            or os.path.join(
                constants.LOG_ROOT,
                constants.experiment_name(),
                constants.trial_name(),
                "swanlab",
            ),
            mode=self.__swanlab_config.mode,
        )

    def step(self):
        # Check whether a new evaluation step should be created.
        ckpt_parent = os.path.join(
            constants.MODEL_SAVE_ROOT,
            constants.experiment_name(),
            constants.trial_name(),
            "actor",
        )
        if os.path.exists(ckpt_parent):
            for ckpt_dir in os.listdir(ckpt_parent):
                ckpt_dir = os.path.join(ckpt_parent, ckpt_dir)
                eval_step = EvaluationStep.from_ckpt_dir(ckpt_dir)
                if eval_step is None:
                    continue
                if eval_step.global_step in self.__eval_steps:
                    continue
                self.__eval_steps[eval_step.global_step] = eval_step
                logger.info(
                    f"Found new checkpoint (globalstep{eval_step.global_step}) "
                    f"at {ckpt_dir}"
                )

        # Submit a pending evaluation step, if any.
        if self.__running_jobs < self.__max_concurrent_jobs:
            # Submit in global_step order.
            pending_steps = list(
                filter(
                    lambda x: self.__eval_steps[x].status
                    == EvaluationStepStatus.PENDING,
                    self.__eval_steps.keys(),
                )
            )
            if pending_steps:
                min_pending = min(pending_steps)
                self.__eval_steps[min_pending].submit(self.__config)
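
        # For example (hypothetical): if pending global steps are {30, 10, 20},
        # only globalstep10 is submitted on this call; 20 and 30 wait for
        # subsequent calls to step() and for free job slots.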

        # Check whether any running eval job is done or failed.
        running_steps = filter(
            lambda x: self.__eval_steps[x].status == EvaluationStepStatus.RUNNING,
            self.__eval_steps.keys(),
        )
        for global_step in running_steps:
            self.__eval_steps[global_step].check()

        # Check whether the **minimal global step** that is neither logged nor
        # failed is done, and log this step to wandb if so.
        # NOTE: LOGGED and FAILED steps have identical behaviors now,
        # but in future versions that support multi-node eval they could differ.
        log_steps = list(
            filter(
                lambda x: self.__eval_steps[x].status
                not in [
                    EvaluationStepStatus.LOGGED,
                    EvaluationStepStatus.FAILED,
                ],
                self.__eval_steps.keys(),
            )
        )
        if log_steps:
            log_step = min(log_steps)
            if self.__eval_steps[log_step].status == EvaluationStepStatus.DONE:
                if not self.__wandb_initialized:
                    self.__lazy_wandb_init()
                    self.__wandb_initialized = True
                if not self.__swanlab_initialized:
                    self.__lazy_swanlab_init()
                    self.__swanlab_initialized = True
                self.__eval_steps[log_step].log(self.__config)

    @property
    def __running_jobs(self):
        return len(
            list(
                filter(
                    lambda x: self.__eval_steps[x].status
                    == EvaluationStepStatus.RUNNING,
                    self.__eval_steps.keys(),
                )
            )
        )
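

# Usage sketch (not part of this module): a driver such as the master worker
# is expected to construct the evaluator once and poll step() periodically.
# The config objects, the training_is_running() helper, and the polling
# interval below are hypothetical.
#
#   evaluator = AutomaticEvaluator(eval_config, wandb_config, swanlab_config)
#   while training_is_running():
#       evaluator.step()  # discover ckpts, submit jobs, log finished results
#       time.sleep(10)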