mirror of https://github.com/inclusionAI/AReaL
PullRequest: 11 Automatically launch evaluation jobs during training
Merge branch mzy/auto-eval of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/11
Signed-off-by: 博惟 <bowei.fw@antgroup.com>

* test
* move evaluator to main process
* .
* clear codes
* add docstring
* .
* separate wandb groups
* .
This commit is contained in: parent deccd47a22, commit e9bf229581

@@ -5,6 +5,7 @@ import subprocess
from glob import glob

import numpy as np
import wandb
from rm_maj_eval import group_pred
from tqdm import tqdm
from transformers import AutoTokenizer

@@ -29,6 +30,7 @@ def parse_args():
    parser.add_argument("--overwrite", action="store_true")
    parser.add_argument("--evaluate_train", action="store_true")
    parser.add_argument("--max_gen_tokens", default=32768, type=int)

    args = parser.parse_args()
    if args.output_path is None:
        args.output_path = args.model_path

@@ -145,7 +147,7 @@ def process_single_data_name(args, data_name, base_dir, tokenizer):

if __name__ == "__main__":
    args = parse_args()

    print(f"Evaluation output to {args.output_path}")
    assert args.num_sample_nodes * args.samples_per_node >= args.n_sampling

    eval_dir = (

@@ -155,6 +157,7 @@ if __name__ == "__main__":
    )

    base_dir = os.path.join(args.output_path, eval_dir)
    os.makedirs(base_dir, exist_ok=True)
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    result_path = os.path.join(base_dir, f"aggregate_parallel_{args.prompt_type}.json")

@@ -228,10 +231,10 @@ if __name__ == "__main__":
        from prettytable import PrettyTable

        table = PrettyTable()
        filed_names = ["dataset"] + list(all_results[args.data_names[0]].keys())
        table.field_names = filed_names
        field_names = ["dataset"] + list(all_results[args.data_names[0]].keys())
        table.field_names = field_names
        for k, v in all_results.items():
            table.add_row([k, *[round(v[x], 1) for x in filed_names[1:]]])
            table.add_row([k, *[round(v[x], 1) for x in field_names[1:]]])

        print(table)
    except:

@@ -137,7 +137,7 @@ def generate_in_parallel(requests, model_args, sampling_params, data_parallel_si
    def run_inference_one_model(
        model_args: dict, sampling_params, requests, cuda_visisble_devices
    ):
        os.environ["VLLM_LOGGING_LEVEL"] = "DEBUG"
        os.environ["VLLM_LOGGING_LEVEL"] = "INFO"
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
            [str(x) for x in cuda_visisble_devices]
        )

@@ -3,7 +3,7 @@ vllm
tqdm
datasets
torch
transformers
transformers==4.47.0
python_dateutil
flash_attn

@@ -12,4 +12,6 @@ sympy==1.12
antlr4-python3-runtime==4.11.1 # ! The version needs to be compatible with sympy.
word2number
Pebble
timeout-decorator
prettytable
timeout-decorator
wandb

@@ -0,0 +1,7 @@
#!/bin/bash
# Users should run this script under AReaL directory

/usr/bin/python3 -m pip install -e evaluation/latex2sympy/
/usr/bin/python3 -m pip install -r evaluation/requirements.txt

cd evaluation && /usr/bin/python3 eval_and_aggregate.py --model_path $1 --output_path $2 --data_names $3 --max_gen_tokens $4 --prompt_type $5
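
Editor's note (not part of the commit): the five positional arguments of this new script are forwarded straight into eval_and_aggregate.py. The sketch below only restates that mapping with placeholder values.

# Placeholder values; the flag order mirrors the last line of the script above.
flags = ["--model_path", "--output_path", "--data_names", "--max_gen_tokens", "--prompt_type"]
values = ["/path/to/ckpt", "/path/to/eval_output", "aime24", "32768", "deepscaler"]
print("python3 eval_and_aggregate.py " + " ".join(f"{f} {v}" for f, v in zip(flags, values)))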

@@ -234,6 +234,36 @@ class ExperimentScheduling:
    controller_image: str = _LLM_CPU_IMAGE


@dataclasses.dataclass
class AutomaticEvaluator:
    """Configuration for automatic evaluation.
    :param data_names: Datasets for evaluation, separated by commas. Currently supports datasets stored under ./evaluation/data,
        including "aime24", "amc23" and "math_500". For example, if "aime24" and "amc23" are required for evaluation,
        this field should be set to "aime24,amc23".
    :type data_names: str
    :param max_gen_tokens: Maximum number of tokens to be generated in evaluation.
    :type max_gen_tokens: int
    :param max_concurrent_jobs: Maximum number of concurrent evaluation jobs to submit. If the number of existing jobs is equal to
        `max_concurrent_jobs` and a new checkpoint is saved, its evaluation job will wait until former jobs complete.
    :type max_concurrent_jobs: int
    :param eval_job_image: Container image used to launch evaluation jobs. If set to None, evaluation jobs will use the
        GPU image for training.
    :type eval_job_image: Optional[str]
    :param initial_checkpoint_path: Initial checkpoint path to evaluate. If specified, this initial checkpoint will be evaluated, and
        its results will be stored as global_step = 0.
    :type initial_checkpoint_path: Optional[str]
    :param prompt_type: Prompt format used in evaluation.
    :type prompt_type: str
    """

    data_names: str = "aime24"
    max_gen_tokens: int = 32768
    max_concurrent_jobs: int = 3
    eval_job_image: Optional[str] = None
    initial_checkpoint_path: Optional[str] = None
    prompt_type: str = "deepscaler"


@dataclasses.dataclass
class WandBConfig:
    mode: str = "disabled"
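
Editor's illustration (not part of the diff): constructing the configuration above with non-default values. The import path follows the `realhf.api.core.system_api` import added elsewhere in this commit; all values are placeholders.

from realhf.api.core.system_api import AutomaticEvaluator

# Placeholder values; field names and defaults come from the dataclass above.
eval_cfg = AutomaticEvaluator(
    data_names="aime24,amc23",   # comma-separated datasets under ./evaluation/data
    max_gen_tokens=32768,
    max_concurrent_jobs=3,
    eval_job_image=None,         # fall back to the training GPU image
    prompt_type="deepscaler",
)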

@@ -256,6 +286,11 @@ class ExperimentConfig:
    model_worker: List[ModelWorker] = dataclasses.field(default_factory=list)
    # master_worker will be set automatically
    master_worker: Optional[List[MasterWorker]] = None
    # automatic evaluation
    auto_eval: bool = False
    evaluator: AutomaticEvaluator = dataclasses.field(
        default_factory=AutomaticEvaluator
    )

    def __post_init__(self):
        self.master_worker = [

@@ -18,6 +18,7 @@ import realhf.base.recover as recover
import realhf.scheduler.client as sched_client
import realhf.system as system
from realhf.scheduler.client import JobException, JobState
from realhf.scheduler.evaluator import AutomaticEvaluator

logger = logging.getLogger("main", "system")


@@ -85,10 +86,17 @@ def main_start(args, recover_count: int = 0):
    # Run initial_setup to go through all sanity checks.
    try:
        exp_cfg = experiment.initial_setup()
        assert isinstance(exp_cfg, config_package.ExperimentConfig)
        exp_cfg.lazy_init()
    except Exception as e:
        raise RuntimeError("Experiment initial setup failed.") from e

    evaluator = (
        AutomaticEvaluator(exp_cfg.evaluator, exp_cfg.wandb)
        if exp_cfg.auto_eval
        else None
    )

    if args.mode == "local":
        assert (
            args.recover_mode == "disabled"

@@ -167,6 +175,7 @@ def main_start(args, recover_count: int = 0):
        expr_name=expr_name,
        trial_name=trial_name,
        schedule_strategy=args.schedule_strategy,
        evaluator=evaluator,
    )

    setup = experiment.scheduling_setup()

@@ -105,7 +105,6 @@ BASE_ENVIRONS = {
    # "TORCH_SHOW_CPP_STACKTRACES": "1",
    # "RAY_DEDUP_LOGS": "0", # disable ray log deduplication
    "CUDA_DEVICE_MAX_CONNECTIONS": "1",
    "PYTHONUSERBASE": "/nonsense", # a random PYTHONUSERBASE to avoid local user site-packages interference
    "OMP_NUM_THREADS": str(min(os.cpu_count(), 32)),
    # torch.distributed.all_reduce does not free the input tensor until
    # the synchronization point. This causes the memory usage to grow

@@ -29,6 +29,7 @@ from realhf.api.core.config import (
from realhf.api.core.dfg import MFCDef, ModelInterfaceType, build_graph
from realhf.api.core.model_api import HF_MODEL_FAMILY_REGISTRY
from realhf.api.core.system_api import (
    AutomaticEvaluator,
    Experiment,
    ExperimentConfig,
    ExperimentSaveEvalControl,

@@ -183,6 +184,12 @@ class CommonExperimentConfig(Experiment):
        torch.cuda.empty_cache() before each RPC in model worker
        If enabled, there will be a ~0.1s overhead per RPC.
    :type torch_cache_mysophobia: bool
    :param auto_eval: Whether to run automatic evaluation during training. When enabled, an evaluation
        job is submitted whenever a checkpoint is saved, and the result will be logged on disk and
        to wandb if wandb is active.
    :type auto_eval: bool
    :param auto_eval_config: Configuration for automatic evaluation.
    :type auto_eval_config: AutomaticEvaluator
    :param cpus_per_master_worker: The number of CPUs for each master worker.
    :param mem_per_master_worker: The size of memory for each master worker, measured in MB.
    :param cpus_per_model_worker: The number of CPUs for each model worker.
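
Editor's note (not part of the diff): combining `eval_output_path` and `__maybe_parse_and_log_to_wandb` from the evaluator added later in this commit, the on-disk result for a checkpoint lands at a path of the following shape; experiment name, trial name and LOG_ROOT below are placeholders.

import os

LOG_ROOT = "/path/to/log_root"            # placeholder for constants.LOG_ROOT
result_path = os.path.join(
    LOG_ROOT, "my-exp", "my-trial", "eval_output",
    "globalstep96",                        # one directory per saved checkpoint
    "math_eval_32768",                     # math_eval_{max_gen_tokens}
    "aggregate_parallel_deepscaler.json",  # aggregate_parallel_{prompt_type}.json
)
print(result_path)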

@@ -214,6 +221,12 @@ class CommonExperimentConfig(Experiment):
        default_factory=ExperimentSaveEvalControl
    )
    torch_cache_mysophobia: bool = True
    # Options for automatic evaluation
    auto_eval: bool = False
    auto_eval_config: AutomaticEvaluator = dataclasses.field(
        default_factory=AutomaticEvaluator
    )
    # Options for worker resources
    cpus_per_master_worker: int = 4
    mem_per_master_worker: int = 20000
    cpus_per_model_worker: int = 4

@@ -716,6 +729,8 @@ class CommonExperimentConfig(Experiment):
            wandb=self.wandb,
            model_rpcs=[rpc_alloc.rpc for rpc_alloc in rpc_allocs],
            model_worker=model_worker,
            auto_eval=self.auto_eval,
            evaluator=self.auto_eval_config,
        )

    def __check_legal_allocation_options(self):

@@ -544,12 +544,15 @@ class PPOMATHConfig(CommonExperimentConfig):
        ######### The main difference from normal PPO #########

        model_worker = self._get_model_worker_configs(rpc_allocs)
        self.auto_eval_config.initial_checkpoint_path = self.actor.path

        return ExperimentConfig(
            exp_ctrl=self.exp_ctrl,
            wandb=self.wandb,
            model_rpcs=[rpc_alloc.rpc for rpc_alloc in rpc_allocs],
            model_worker=model_worker,
            auto_eval=self.auto_eval,
            evaluator=self.auto_eval_config,
        )


@@ -162,7 +162,8 @@ def make(mode, expr_name, trial_name, **kwargs) -> SchedulerClient:
        from realhf.scheduler.slurm.client import SlurmSchedulerClient

        schedule_strategy = kwargs.get("schedule_strategy", "empty_first")
        return SlurmSchedulerClient(expr_name, trial_name, schedule_strategy)
        evaluator = kwargs.get("evaluator", None)
        return SlurmSchedulerClient(expr_name, trial_name, schedule_strategy, evaluator)
    elif mode == "local":
        from realhf.scheduler.local.client import LocalSchedulerClient

@@ -0,0 +1,246 @@
import json
import os
import re
import subprocess
import time
from typing import Dict

import wandb

import realhf.api.core.system_api as config_pkg
from realhf.base import cluster, constants, logging

logger = logging.getLogger("AutomaticEvaluator", "colored")


class AutomaticEvaluator:

    def __init__(
        self,
        config: config_pkg.AutomaticEvaluator,
        wandb_config: config_pkg.WandBConfig,
    ):
        self.__running_processes: Dict[int, subprocess.Popen] = {}
        self.__start_time = {}
        self.__done_steps = []
        self.__wandb_log_steps = []
        self.__pending_ckpts = {}
        self.__image = config.eval_job_image or cluster.spec.gpu_image
        self.__data_names = config.data_names
        self.__max_gen_tokens = config.max_gen_tokens
        self.__max_concurrent_jobs = config.max_concurrent_jobs
        self.__prompt_type = config.prompt_type
        self.__wandb_config = wandb_config
        self.__config = config
        self.__wandb_inited = False

        # Check evaluated checkpoints by logs
        former_output_dir = os.path.join(
            constants.LOG_ROOT,
            constants.experiment_name(),
            constants.trial_name(),
            "eval_output",
        )
        if os.path.exists(former_output_dir):
            for log_dir in os.listdir(former_output_dir):
                match = re.match(r"globalstep(\d+)", log_dir)
                if not match:
                    continue
                global_step = int(match.groups()[0])
                self.__done_steps.append(global_step)

        logger.info(
            f"Initializing AutomaticEvaluator: \n"
            f"eval_job_image: {config.eval_job_image}\n"
            f"data_names: {config.data_names}\n"
            f"max_gen_tokens: {config.max_gen_tokens}\n"
            f"max_concurrent_jobs: {config.max_concurrent_jobs}\n"
            f"Existing eval outputs for global steps: {self.__done_steps}"
        )
        if self.__config.initial_checkpoint_path and 0 not in self.__done_steps:
            self.__pending_ckpts[0] = self.__config.initial_checkpoint_path

        if not cluster.spec.cluster_type == "slurm":
            raise NotImplementedError(
                "Currently only support automatic evaluation for slurm"
            )

    def __lazy_wandb_init(self):
        # Initializing wandb for evaluator
        wandb.login()
        wandb.init(
            mode=self.__wandb_config.mode,
            entity=self.__wandb_config.entity,
            project=self.__wandb_config.project or constants.experiment_name(),
            name=self.__wandb_config.name or f"{constants.trial_name()}_eval",
            job_type=self.__wandb_config.job_type,
            group=self.__wandb_config.group
            or f"{constants.experiment_name()}_{constants.trial_name()}",
            notes=self.__wandb_config.notes,
            tags=self.__wandb_config.tags,
            config=self.__wandb_config.config,
            dir=os.path.join(
                constants.LOG_ROOT, constants.experiment_name(), constants.trial_name()
            ),
            force=True,
            id=f"{constants.experiment_name()}_{constants.trial_name()}_eval",
            resume="allow",
            settings=wandb.Settings(start_method="fork"),
        )

    def __check_new_ckpts(self):
        save_path = os.path.join(
            constants.MODEL_SAVE_ROOT,
            constants.experiment_name(),
            constants.trial_name(),
            "actor",
        )
        if not os.path.exists(save_path):
            return
        for ckpt_dir in os.listdir(save_path):
            match = re.match(r"epoch(\d+)epochstep(\d+)globalstep(\d+)", ckpt_dir)
            if not match:
                continue
            _, _, global_step = map(int, match.groups())
            if not global_step in (
                list(self.__running_processes.keys())
                + list(self.__pending_ckpts.keys())
                + self.__done_steps
            ):
                abs_ckpt_dir = os.path.join(save_path, ckpt_dir)
                logger.info(
                    f"Found new checkpoint (globalstep{global_step}) at {abs_ckpt_dir}"
                )
                self.__pending_ckpts[global_step] = os.path.join(
                    save_path, abs_ckpt_dir
                )

    def __check_and_maybe_submit_jobs(self):
        for global_step, process in self.__running_processes.items():
            result = process.poll()
            if not result is None:
                self.__done_steps.append(global_step)
                start_time = self.__start_time[global_step]
                logger.info(
                    f"Evaluation of checkpoint (globalstep{global_step}) is done, returncode={process.returncode}, "
                    f"time passed {time.perf_counter() - start_time:.3f} s."
                )

        for done in self.__done_steps:
            if done in self.__running_processes:
                self.__running_processes.pop(done)
                self.__start_time.pop(done)

        submitted = []
        # Jobs should be submitted by the order of global steps
        ordered_steps = sorted(self.__pending_ckpts.keys())
        for global_step in ordered_steps:
            ckpt_path = self.__pending_ckpts[global_step]
            if len(self.__running_processes) >= self.__max_concurrent_jobs:
                return
            self.__submit_one(global_step, ckpt_path)
            submitted.append(global_step)

        for global_step in submitted:
            self.__pending_ckpts.pop(global_step)

    def __submit_one(self, global_step, ckpt_path):
        output_path = self.eval_output_path(global_step)
        os.makedirs(output_path, exist_ok=True)
        log_file = open(os.path.join(output_path, "output.log"), "w")
        if cluster.spec.cluster_type == "slurm":
            cmd = self.slurm_eval_cmd(global_step, ckpt_path)
        else:
            raise NotImplementedError(
                "AutomaticEvaluator does only support slurm cluster."
            )

        logger.info(
            f"Submitting evaluation job of checkpoint at {ckpt_path} (globalstep{global_step}), "
            f"command: {cmd}"
        )
        self.__running_processes[global_step] = subprocess.Popen(
            cmd,
            stdout=log_file,
            stderr=log_file,
            shell=True,
        )
        self.__start_time[global_step] = time.perf_counter()

    def __maybe_parse_and_log_to_wandb(self):
        # Note that after recover, all previous done steps will be
        # logged again in case some data points are missing.
        # If the data point is already logged, wandb will raise
        # a warning.
        to_log = list(
            filter(
                lambda x: x not in self.__wandb_log_steps,
                (
                    self.__done_steps
                    + list(self.__running_processes.keys())
                    + list(self.__pending_ckpts.keys())
                ),
            )
        )
        while to_log:
            # The wandb should always log the minimal global step
            # whose eval job has been submitted but not logged.
            # If this minimal step is not logged, other steps should wait.
            global_step = min(to_log)
            result_path = os.path.join(
                self.eval_output_path(global_step),
                f"math_eval_{self.__max_gen_tokens}",
                f"aggregate_parallel_{self.__prompt_type}.json",
            )
            if not os.path.exists(result_path):
                break

            if not self.__wandb_inited:
                self.__lazy_wandb_init()
                self.__wandb_inited = True

            try:
                with open(result_path, "r") as fp:
                    data = json.load(fp)
            except json.JSONDecodeError:
                logger.warning(f"JSON decoding for eval result in {result_path} failed")
                continue
            except FileNotFoundError:
                logger.warning(
                    f"{result_path} not found, but the eval job is done. "
                    "Maybe the eval job abnormally exited and did not output the result."
                )
                continue
            wandb_data = {}
            for data_name, d in data.items():
                for k, v in d.items():
                    wandb_data[f"{data_name}_{k}"] = v
            wandb.log(wandb_data, step=global_step)
            logger.info(f"Logging eval result {wandb_data} to step {global_step}")
            self.__wandb_log_steps.append(global_step)
            to_log.remove(global_step)

    def step(self):
        self.__check_new_ckpts()
        self.__check_and_maybe_submit_jobs()
        self.__maybe_parse_and_log_to_wandb()

    def slurm_eval_cmd(self, global_step, ckpt_path):
        slurm_job_name = f"{constants.experiment_name()}_{constants.trial_name()}:eval_globalstep{global_step}"
        cmd = (
            f"srun --mpi=pmi2 -J {slurm_job_name} --ntasks=1 --cpus-per-task=128 --gres=gpu:8 --mem-per-cpu=12G "
            f"singularity exec --nv --pid --writable-tmpfs --bind /storage:/storage "
            f"{self.__image} "
            f"bash ./evaluation/sh/install_deps_and_eval.sh {ckpt_path} {self.eval_output_path(global_step)} "
            f"{self.__data_names} {self.__max_gen_tokens} {self.__prompt_type}"
        )
        return cmd

    def eval_output_path(self, global_step):
        return os.path.join(
            constants.LOG_ROOT,
            constants.experiment_name(),
            constants.trial_name(),
            "eval_output",
            f"globalstep{global_step}",
        )
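
Editor's illustration (not part of the commit): the naming conventions and the metric flattening that the evaluator above relies on. The checkpoint name, metric key and value below are made-up examples.

import re

# Trainer checkpoints are matched as "epoch{X}epochstep{Y}globalstep{Z}".
ckpt_dir = "epoch1epochstep8globalstep96"   # example name
_, _, global_step = map(int, re.match(r"epoch(\d+)epochstep(\d+)globalstep(\d+)", ckpt_dir).groups())

# Eval outputs are grouped per global step; on recovery they are re-discovered via r"globalstep(\d+)".
assert re.match(r"globalstep(\d+)", f"globalstep{global_step}").groups() == ("96",)

# The aggregate JSON ({dataset: {metric: value}}) is flattened to "{dataset}_{metric}" before wandb.log.
data = {"aime24": {"accuracy": 30.0}}        # placeholder metric name and value
wandb_data = {f"{name}_{k}": v for name, d in data.items() for k, v in d.items()}
print(global_step, wandb_data)               # 96 {'aime24_accuracy': 30.0}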

@@ -14,6 +14,7 @@ import realhf.base.logging as logging
from realhf.base.cluster import spec as cluster_spec
from realhf.base.constants import SLURM_LOCK_FILE_NAME as LOCK_FILE_NAME
from realhf.scheduler.client import JobException, JobInfo, JobState, SchedulerClient
from realhf.scheduler.evaluator import AutomaticEvaluator
from realhf.scheduler.slurm.utils import (
    SlurmLaunchInfo,
    SlurmResource,

@@ -25,12 +26,19 @@ logger = logging.getLogger("Slurm-scheduler")

SCHEDULING_RETRY_INTERVAL_SECONDS = 30
SCHEDULING_TIMEOUT_MAX_SECONDS = 3600 * 24
SCHEDULER_WAIT_CHECK_TIME_INTERVAL = 5


class SlurmSchedulerClient(SchedulerClient):
    """Uses Slurm (https://slurm.schedmd.com/overview.html)."""

    def __init__(self, expr_name, trial_name, schedule_strategy):
    def __init__(
        self,
        expr_name: str,
        trial_name: str,
        schedule_strategy: str,
        evaluator: Optional[AutomaticEvaluator],
    ):
        super().__init__(expr_name, trial_name)

        self.__schedule_strategy = schedule_strategy

@@ -40,6 +48,7 @@ class SlurmSchedulerClient(SchedulerClient):

        self.__submission_counter = defaultdict(int)
        self.__wprocs_counter = defaultdict(int)
        self.__evaluator = evaluator

    def submit(self, worker_type, cmd, **kwargs):
        self.submit_array(worker_type, cmd, count=1, **kwargs)

@@ -245,6 +254,8 @@ class SlurmSchedulerClient(SchedulerClient):
            if len(left) < num_jobs_left:
                num_jobs_left = len(left)
                logger.info(f"Waiting for {num_jobs_left} jobs.")
            if self.__evaluator is not None:
                self.__evaluator.step()
            if deadline is not None and time.time() > deadline:
                raise TimeoutError(
                    f"Timeout waiting for {self.run_name}: {', '.join(sorted(left))}"

@@ -276,7 +287,7 @@ class SlurmSchedulerClient(SchedulerClient):
                left.remove(job_slurm_name)
                if update:
                    self.__committed_jobs.pop(job_slurm_name)
            time.sleep(2)
            time.sleep(SCHEDULER_WAIT_CHECK_TIME_INTERVAL)

    def __update_all(self):
        states = []

@@ -450,7 +450,7 @@ class RayController:

    It uses the basic Controller to configure workers. Besides, it
    launchs all remote workers using Ray, instead of submitting them to
    the scheduelr.
    the scheduler.
    """

    def __init__(self, experiment_name, trial_name):

@@ -963,9 +963,10 @@ class MasterWorker(worker_base.Worker):
            mode=self.wandb_config.mode,
            entity=self.wandb_config.entity,
            project=self.wandb_config.project or constants.experiment_name(),
            name=self.wandb_config.name or constants.trial_name(),
            name=self.wandb_config.name or f"{constants.trial_name()}_train",
            job_type=self.wandb_config.job_type,
            group=self.wandb_config.group,
            group=self.wandb_config.group
            or f"{constants.experiment_name()}_{constants.trial_name()}",
            notes=self.wandb_config.notes,
            tags=self.wandb_config.tags,
            config=self.wandb_config.config,

@@ -973,6 +974,7 @@ class MasterWorker(worker_base.Worker):
                constants.LOG_ROOT, constants.experiment_name(), constants.trial_name()
            ),
            force=True,
            id=f"{constants.experiment_name()}_{constants.trial_name()}_train",
            resume="allow",
            settings=wandb.Settings(start_method="fork"),
        )
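
Editor's illustration (not part of the diff): with the two hunks above and the evaluator's `__lazy_wandb_init`, the training and evaluation runs share one wandb group but use distinct names and ids, so they appear side by side under the same group. Experiment and trial names below are placeholders.

experiment_name, trial_name = "my-exp", "my-trial"   # placeholders
group = f"{experiment_name}_{trial_name}"
train_run = {"name": f"{trial_name}_train", "id": f"{group}_train", "group": group}
eval_run = {"name": f"{trial_name}_eval", "id": f"{group}_eval", "group": group}
print(train_run, eval_run, sep="\n")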