diff --git a/evaluation/eval_and_aggregate.py b/evaluation/eval_and_aggregate.py
index be5d042..2a14a71 100644
--- a/evaluation/eval_and_aggregate.py
+++ b/evaluation/eval_and_aggregate.py
@@ -5,7 +5,6 @@ import subprocess
 from glob import glob
 
 import numpy as np
-import wandb
 from rm_maj_eval import group_pred
 from tqdm import tqdm
 from transformers import AutoTokenizer
@@ -110,7 +109,6 @@ def get_metrics(fname_pattern, tokenizer, is_greedy):
             "num_questions": len(lengths),
         }
     else:
-
         return {
             "sample_length": np.mean(lengths),
            "sample_pass@1": pass_at_k(results.values(), 1),
diff --git a/realhf/scheduler/evaluator.py b/realhf/scheduler/evaluator.py
index 997a0fe..a24029e 100644
--- a/realhf/scheduler/evaluator.py
+++ b/realhf/scheduler/evaluator.py
@@ -1,9 +1,12 @@
+import dataclasses
+import enum
 import json
 import os
+import pathlib
 import re
 import subprocess
 import time
-from typing import Dict
+from typing import Dict, Optional
 
 import wandb
 
@@ -13,6 +16,138 @@ from realhf.base import cluster, constants, logging
 logger = logging.getLogger("AutomaticEvaluator", "colored")
 
 
+class EvaluationStepStatus(enum.Enum):
+    PENDING = 0
+    RUNNING = 1
+    FAILED = 2
+    DONE = 3
+    LOGGED = 4
+
+
+@dataclasses.dataclass
+class EvaluationStep:
+    global_step: int
+    status: EvaluationStepStatus
+    start_time: Optional[float] = None
+    ckpt_dir: Optional[str] = None
+    process: Optional[subprocess.Popen] = None
+
+    @staticmethod
+    def from_ckpt_dir(ckpt_dir):
+        # NOTE: ckpt_dir should be an absolute path.
+        if pathlib.Path(ckpt_dir).is_symlink():
+            return None
+        _dir = os.path.basename(ckpt_dir)
+        match = re.match(r"epoch(\d+)epochstep(\d+)globalstep(\d+)", _dir)
+        if not match:
+            return None
+        _, _, global_step = map(int, match.groups())
+        return EvaluationStep(
+            global_step=global_step,
+            status=EvaluationStepStatus.PENDING,
+            ckpt_dir=ckpt_dir,
+        )
+
+    @staticmethod
+    def from_output_dir(output_dir):
+        # NOTE: output_dir should be an absolute path.
+        # Should only be called during recovery.
+        _dir = os.path.basename(output_dir)
+        match = re.match(r"globalstep(\d+)", _dir)
+        if not match:
+            return None
+        global_step = int(match.groups()[0])
+        return EvaluationStep(
+            global_step=global_step, status=EvaluationStepStatus.LOGGED
+        )
+
+    @property
+    def output_dir(self):
+        return os.path.join(
+            constants.LOG_ROOT,
+            constants.experiment_name(),
+            constants.trial_name(),
+            "eval_output",
+            f"globalstep{self.global_step}",
+        )
+
+    def slurm_eval_cmd(self, config: config_pkg.AutomaticEvaluator):
+        slurm_job_name = f"{constants.experiment_name()}_{constants.trial_name()}:eval_globalstep{self.global_step}"
+        cmd = (
+            f"srun --mpi=pmi2 -J {slurm_job_name} --ntasks=1 --cpus-per-task=128 --gres=gpu:8 --mem-per-cpu=12G "
+            f"singularity exec --no-home --nv --pid --writable-tmpfs --bind /storage:/storage "
+            f"{config.eval_job_image or cluster.spec.gpu_image} "
+            f"bash ./evaluation/sh/install_deps_and_eval.sh {self.ckpt_dir} {self.output_dir} "
+            f"{config.data_names} {config.max_gen_tokens} {config.prompt_type}"
+        )
+        return cmd
+
+    def submit(self, config: config_pkg.AutomaticEvaluator):
+        os.makedirs(self.output_dir, exist_ok=True)
+        log_file = open(os.path.join(self.output_dir, "output.log"), "w")
+        if cluster.spec.cluster_type == "slurm":
+            cmd = self.slurm_eval_cmd(config)
+        else:
+            raise NotImplementedError(
+                "AutomaticEvaluator only supports slurm clusters."
+            )
+
+        logger.info(
+            f"Submitting evaluation job of checkpoint at {self.ckpt_dir} (globalstep{self.global_step}), "
+            f"command: {cmd}"
+        )
+        self.process = subprocess.Popen(
+            cmd,
+            stdout=log_file,
+            stderr=log_file,
+            shell=True,
+        )
+        self.start_time = time.perf_counter()
+        self.status = EvaluationStepStatus.RUNNING
+
+    def log(self, config: config_pkg.AutomaticEvaluator) -> bool:
+        result_path = os.path.join(
+            self.output_dir,
+            f"math_eval_{config.max_gen_tokens}",
+            f"aggregate_parallel_{config.prompt_type}.json",
+        )
+        # NOTE: If the JSON result cannot be decoded or is not found,
+        # the evaluation step will be marked as failed.
+        try:
+            with open(result_path, "r") as fp:
+                data = json.load(fp)
+        except json.JSONDecodeError:
+            logger.warning(f"JSON file {result_path} decoding failed.")
+            self.status = EvaluationStepStatus.FAILED
+            return False
+        except FileNotFoundError:
+            logger.warning(f"JSON file {result_path} does not exist.")
+            self.status = EvaluationStepStatus.FAILED
+            return False
+
+        wandb_data = {}
+        for data_name, d in data.items():
+            for k, v in d.items():
+                wandb_data[f"{data_name}_{k}"] = v
+        wandb.log(wandb_data, step=self.global_step)
+        self.status = EvaluationStepStatus.LOGGED
+        logger.info(f"Logging eval result {wandb_data} to step {self.global_step}")
+        return True
+
+    def check(self):
+        assert self.process is not None
+        result = self.process.poll()
+        if result is not None:
+            logger.info(
+                f"Evaluation of checkpoint (globalstep{self.global_step}) is done, returncode={self.process.returncode}, "
+                f"time passed {time.perf_counter() - self.start_time:.3f} s."
+            )
+            if self.process.returncode == 0:
+                self.status = EvaluationStepStatus.DONE
+            else:
+                self.status = EvaluationStepStatus.FAILED
+
+
 class AutomaticEvaluator:
 
     def __init__(
@@ -20,34 +155,31 @@ class AutomaticEvaluator:
         config: config_pkg.AutomaticEvaluator,
         wandb_config: config_pkg.WandBConfig,
     ):
-        self.__running_processes: Dict[int, subprocess.Popen] = {}
-        self.__start_time = {}
-        self.__done_steps = []
-        self.__wandb_log_steps = []
-        self.__pending_ckpts = {}
-        self.__image = config.eval_job_image or cluster.spec.gpu_image
-        self.__data_names = config.data_names
-        self.__max_gen_tokens = config.max_gen_tokens
+        self.__eval_steps: Dict[int, EvaluationStep] = {}
         self.__max_concurrent_jobs = config.max_concurrent_jobs
-        self.__prompt_type = config.prompt_type
         self.__wandb_config = wandb_config
        self.__config = config
-        self.__wandb_inited = False
+        self.__wandb_initialized = False
 
-        # Check evaluated checkpoints by logs
-        former_output_dir = os.path.join(
+        # Check already-evaluated checkpoints from logs during recovery.
+        # NOTE: All previous evaluation steps that produced output will be
+        # marked as logged, even if they were never actually logged to wandb.
+        # This is because we do not know the status of evaluation jobs
+        # submitted before the recovery.
+        # Resubmitting or waiting for these jobs would probably result in
+        # unexpected behaviors.
+        output_parent = os.path.join(
             constants.LOG_ROOT,
             constants.experiment_name(),
             constants.trial_name(),
             "eval_output",
         )
-        if os.path.exists(former_output_dir):
-            for log_dir in os.listdir(former_output_dir):
-                match = re.match(r"globalstep(\d+)", log_dir)
-                if not match:
-                    continue
-                global_step = int(match.groups()[0])
-                self.__done_steps.append(global_step)
+        if os.path.exists(output_parent):
+            for output_dir in os.listdir(output_parent):
+                output_dir = os.path.join(output_parent, output_dir)
+                eval_step = EvaluationStep.from_output_dir(output_dir)
+                if eval_step:
+                    self.__eval_steps[eval_step.global_step] = eval_step
 
         logger.info(
             f"Initializing AutomaticEvaluator: \n"
@@ -55,10 +187,15 @@ class AutomaticEvaluator:
             f"data_names: {config.data_names}\n"
             f"max_gen_tokens: {config.max_gen_tokens}\n"
             f"max_concurrent_jobs: {config.max_concurrent_jobs}\n"
-            f"Existing eval outputs for global steps: {self.__done_steps}"
+            f"Existing eval outputs for global steps: "
+            f"{list(self.__eval_steps.keys())}"
         )
-        if self.__config.initial_checkpoint_path and 0 not in self.__done_steps:
-            self.__pending_ckpts[0] = self.__config.initial_checkpoint_path
+        if self.__config.initial_checkpoint_path and 0 not in self.__eval_steps:
+            self.__eval_steps[0] = EvaluationStep(
+                global_step=0,
+                status=EvaluationStepStatus.PENDING,
+                ckpt_dir=self.__config.initial_checkpoint_path,
+            )
 
         if not cluster.spec.cluster_type == "slurm":
             raise NotImplementedError(
@@ -66,7 +203,10 @@ class AutomaticEvaluator:
         )
 
     def __lazy_wandb_init(self):
-        # Initializing wandb for evaluator
+        # Initialize wandb for the evaluator.
+        # We initialize lazily because if this wandb run is created at the
+        # same moment as the wandb run on the master worker, without any
+        # time interval between them, one of the two will fail.
         wandb.login()
         wandb.init(
             mode=self.__wandb_config.mode,
@@ -88,159 +228,80 @@ class AutomaticEvaluator:
             settings=wandb.Settings(start_method="fork"),
         )
 
-    def __check_new_ckpts(self):
-        save_path = os.path.join(
+    def step(self):
+        # Check whether a new evaluation step should be created
+        ckpt_parent = os.path.join(
             constants.MODEL_SAVE_ROOT,
             constants.experiment_name(),
             constants.trial_name(),
             "actor",
         )
-        if not os.path.exists(save_path):
-            return
-        for ckpt_dir in os.listdir(save_path):
-            match = re.match(r"epoch(\d+)epochstep(\d+)globalstep(\d+)", ckpt_dir)
-            if not match:
-                continue
-            _, _, global_step = map(int, match.groups())
-            if not global_step in (
-                list(self.__running_processes.keys())
-                + list(self.__pending_ckpts.keys())
-                + self.__done_steps
-            ):
-                abs_ckpt_dir = os.path.join(save_path, ckpt_dir)
+        if os.path.exists(ckpt_parent):
+            for ckpt_dir in os.listdir(ckpt_parent):
+                ckpt_dir = os.path.join(ckpt_parent, ckpt_dir)
+                eval_step = EvaluationStep.from_ckpt_dir(ckpt_dir)
+                if eval_step is None:
+                    continue
+                if eval_step.global_step in self.__eval_steps:
+                    continue
+                self.__eval_steps[eval_step.global_step] = eval_step
                 logger.info(
-                    f"Found new checkpoint (globalstep{global_step}) at {abs_ckpt_dir}"
-                )
-                self.__pending_ckpts[global_step] = os.path.join(
-                    save_path, abs_ckpt_dir
+                    f"Found new checkpoint (globalstep{eval_step.global_step}) "
+                    f"at {ckpt_dir}"
                 )
 
-    def __check_and_maybe_submit_jobs(self):
-        for global_step, process in self.__running_processes.items():
-            result = process.poll()
-            if not result is None:
-                self.__done_steps.append(global_step)
-                start_time = self.__start_time[global_step]
-                logger.info(
-                    f"Evaluation of checkpoint (globalstep{global_step}) is done, returncode={process.returncode}, "
-                    f"time passed {time.perf_counter() - start_time:.3f} s."
+        # Submit one pending evaluation step if below the concurrency limit
+        if self.__running_jobs < self.__max_concurrent_jobs:
+            # Submit in global_step order
+            pending_steps = list(
+                filter(
+                    lambda x: self.__eval_steps[x].status
+                    == EvaluationStepStatus.PENDING,
+                    self.__eval_steps.keys(),
                )
-
-        for done in self.__done_steps:
-            if done in self.__running_processes:
-                self.__running_processes.pop(done)
-                self.__start_time.pop(done)
-
-        submitted = []
-        # Jobs should be submitted by the order of global steps
-        ordered_steps = sorted(self.__pending_ckpts.keys())
-        for global_step in ordered_steps:
-            ckpt_path = self.__pending_ckpts[global_step]
-            if len(self.__running_processes) >= self.__max_concurrent_jobs:
-                return
-            self.__submit_one(global_step, ckpt_path)
-            submitted.append(global_step)
-
-        for global_step in submitted:
-            self.__pending_ckpts.pop(global_step)
-
-    def __submit_one(self, global_step, ckpt_path):
-        output_path = self.eval_output_path(global_step)
-        os.makedirs(output_path, exist_ok=True)
-        log_file = open(os.path.join(output_path, "output.log"), "w")
-        if cluster.spec.cluster_type == "slurm":
-            cmd = self.slurm_eval_cmd(global_step, ckpt_path)
-        else:
-            raise NotImplementedError(
-                "AutomaticEvaluator does only support slurm cluster."
             )
+            if pending_steps:
+                min_pending = min(pending_steps)
+                self.__eval_steps[min_pending].submit(self.__config)
 
-        logger.info(
-            f"Submitting evaluation job of checkpoint at {ckpt_path} (globalstep{global_step}), "
-            f"command: {cmd}"
+        # Check if any eval job is done or failed
+        running_steps = filter(
+            lambda x: self.__eval_steps[x].status == EvaluationStepStatus.RUNNING,
+            self.__eval_steps.keys(),
         )
-        self.__running_processes[global_step] = subprocess.Popen(
-            cmd,
-            stdout=log_file,
-            stderr=log_file,
-            shell=True,
-        )
-        self.__start_time[global_step] = time.perf_counter()
+        for global_step in running_steps:
+            self.__eval_steps[global_step].check()
 
-    def __maybe_parse_and_log_to_wandb(self):
-        # Note that after recover, all previous done steps will be
-        # logged again in case some data points are missing.
-        # If the data point is already logged, wandb will raise
-        # a warning.
-        to_log = list(
+        # Check whether the **minimal** global step that is neither logged nor
+        # failed is done, and if so, log it to wandb.
+        # NOTE: LOGGED and FAILED steps behave identically for now, but in
+        # future versions that support multi-node eval they could differ.
+        log_steps = list(
             filter(
-                lambda x: x not in self.__wandb_log_steps,
-                (
-                    self.__done_steps
-                    + list(self.__running_processes.keys())
-                    + list(self.__pending_ckpts.keys())
-                ),
+                lambda x: self.__eval_steps[x].status
+                not in [
+                    EvaluationStepStatus.LOGGED,
+                    EvaluationStepStatus.FAILED,
+                ],
+                self.__eval_steps.keys(),
             )
         )
-        while to_log:
-            # The wandb should always log the minimal global step
-            # whose eval job has been submitted but not logged.
-            # If this minimal step is not logged, other steps should wait.
-            global_step = min(to_log)
-            result_path = os.path.join(
-                self.eval_output_path(global_step),
-                f"math_eval_{self.__max_gen_tokens}",
-                f"aggregate_parallel_{self.__prompt_type}.json",
-            )
-            if not os.path.exists(result_path):
-                break
+        if log_steps:
+            log_step = min(log_steps)
+            if self.__eval_steps[log_step].status == EvaluationStepStatus.DONE:
+                if not self.__wandb_initialized:
+                    self.__lazy_wandb_init()
+                    self.__wandb_initialized = True
+                self.__eval_steps[log_step].log(self.__config)
 
-            if not self.__wandb_inited:
-                self.__lazy_wandb_init()
-                self.__wandb_inited = True
-
-            try:
-                with open(result_path, "r") as fp:
-                    data = json.load(fp)
-            except json.JSONDecodeError:
-                logger.warning(f"JSON decoding for eval result in {result_path} failed")
-                continue
-            except FileNotFoundError:
-                logger.warning(
-                    f"{result_path} not found, but the eval job is done. "
-                    "Maybe the eval job abnormally exited and did not output the result."
+    @property
+    def __running_jobs(self):
+        return len(
+            list(
+                filter(
+                    lambda x: self.__eval_steps[x].status
+                    == EvaluationStepStatus.RUNNING,
+                    self.__eval_steps.keys(),
                 )
-                continue
-            wandb_data = {}
-            for data_name, d in data.items():
-                for k, v in d.items():
-                    wandb_data[f"{data_name}_{k}"] = v
-            wandb.log(wandb_data, step=global_step)
-            logger.info(f"Logging eval result {wandb_data} to step {global_step}")
-            self.__wandb_log_steps.append(global_step)
-            to_log.remove(global_step)
-
-    def step(self):
-        self.__check_new_ckpts()
-        self.__check_and_maybe_submit_jobs()
-        self.__maybe_parse_and_log_to_wandb()
-
-    def slurm_eval_cmd(self, global_step, ckpt_path):
-        slurm_job_name = f"{constants.experiment_name()}_{constants.trial_name()}:eval_globalstep{global_step}"
-        cmd = (
-            f"srun --mpi=pmi2 -J {slurm_job_name} --ntasks=1 --cpus-per-task=128 --gres=gpu:8 --mem-per-cpu=12G "
-            f"singularity exec --nv --pid --writable-tmpfs --bind /storage:/storage "
-            f"{self.__image} "
-            f"bash ./evaluation/sh/install_deps_and_eval.sh {ckpt_path} {self.eval_output_path(global_step)} "
-            f"{self.__data_names} {self.__max_gen_tokens} {self.__prompt_type}"
-        )
-        return cmd
-
-    def eval_output_path(self, global_step):
-        return os.path.join(
-            constants.LOG_ROOT,
-            constants.experiment_name(),
-            constants.trial_name(),
-            "eval_output",
-            f"globalstep{global_step}",
+            )
         )
diff --git a/realhf/scheduler/slurm/client.py b/realhf/scheduler/slurm/client.py
index 43234e8..b787b70 100644
--- a/realhf/scheduler/slurm/client.py
+++ b/realhf/scheduler/slurm/client.py
@@ -136,7 +136,6 @@ class SlurmSchedulerClient(SchedulerClient):
             wrap_cmd += f"--bind {launch_info.container_mounts} "
         wrap_cmd += f"{launch_info.container_image} "
         wrap_cmd += "bash -c '{}'".format(cmd)
-        wrap_cmd = wrap_cmd.format(cmd)
         launch_info.multiprog_content = f"0-{launch_info.n_jobsteps - 1} {wrap_cmd}\n"
 
         return launch_info
diff --git a/realhf/system/master_worker.py b/realhf/system/master_worker.py
index d6dc2ca..364726e 100644
--- a/realhf/system/master_worker.py
+++ b/realhf/system/master_worker.py
@@ -289,9 +289,10 @@ class MasterWorker(worker_base.Worker):
             mode=self.wandb_config.mode,
             entity=self.wandb_config.entity,
             project=self.wandb_config.project or constants.experiment_name(),
-            name=self.wandb_config.name or constants.trial_name(),
+            name=self.wandb_config.name or f"{constants.trial_name()}_train",
             job_type=self.wandb_config.job_type,
-            group=self.wandb_config.group,
+            group=self.wandb_config.group
+            or f"{constants.experiment_name()}_{constants.trial_name()}",
             notes=self.wandb_config.notes,
             tags=self.wandb_config.tags,
             config=self.wandb_config.config,
@@ -299,6 +300,7 @@
                 constants.LOG_ROOT, constants.experiment_name(), constants.trial_name()
             ),
             force=True,
+            id=f"{constants.experiment_name()}_{constants.trial_name()}_train",
             resume="allow",
             settings=wandb.Settings(start_method="fork"),
         )
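
Note (not part of the patch): the new EvaluationStep.from_ckpt_dir and EvaluationStep.from_output_dir constructors depend on the directory naming conventions used elsewhere in the system (epoch{e}epochstep{s}globalstep{g} for saved checkpoints under actor/, globalstep{g} for eval outputs under eval_output/). The short Python sketch below only mirrors those two regexes in isolation to show how the wandb logging step is recovered from a directory name; the directory names are hypothetical examples, not values produced by this diff.

import re

# Mirrors EvaluationStep.from_ckpt_dir: a checkpoint directory name encodes
# epoch, epoch step, and global step; the global step becomes the wandb step
# and the step starts in the PENDING state.
ckpt_name = "epoch2epochstep7globalstep350"  # hypothetical example
m = re.match(r"epoch(\d+)epochstep(\d+)globalstep(\d+)", ckpt_name)
assert m is not None
epoch, epoch_step, global_step = map(int, m.groups())
print(global_step)  # -> 350

# Mirrors EvaluationStep.from_output_dir: an existing eval output directory is
# treated as already LOGGED when the evaluator recovers after a restart.
out_name = "globalstep350"  # hypothetical example
m = re.match(r"globalstep(\d+)", out_name)
assert m is not None
print(int(m.groups()[0]))  # -> 350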