mirror of https://github.com/inclusionAI/AReaL
PullRequest: 63 simplify startup command
Merge branch simplify-startup of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/63?tab=diff Signed-off-by: 博惟 <bowei.fw@antgroup.com>
This commit is contained in:
commit
71429c9655
|
@ -0,0 +1,86 @@
|
|||
experiment_name: areal-1.5B-distill-gpus-8
|
||||
trial_name: 1024x8
|
||||
mode: ray
|
||||
wandb:
|
||||
mode: disabled
|
||||
recover_mode: auto
|
||||
recover_retries: 10
|
||||
allocation_mode: 'sglang.d4p1m1+d4p1m1'
|
||||
n_nodes: 1
|
||||
n_gpus_per_node: 8
|
||||
cache_clear_freq: 1
|
||||
exp_ctrl:
|
||||
total_train_epochs: 10
|
||||
save_freq_epochs: 1
|
||||
ckpt_freq_secs: 600
|
||||
torch_cache_mysophobia: true
|
||||
actor:
|
||||
type:
|
||||
_class: qwen2
|
||||
path: '/storage/models/DeepSeek-R1-Distill-Qwen-1.5B'
|
||||
optimizer:
|
||||
lr: 1.0e-05
|
||||
lr_scheduler_type: constant
|
||||
initial_loss_scale: 262144.0
|
||||
loss_scale_window: 5.0
|
||||
hysteresis: 2
|
||||
sglang:
|
||||
disable_radix_cache: true
|
||||
context_length: 18432
|
||||
mem_fraction_static: 0.8
|
||||
max_running_requests: 128
|
||||
critic:
|
||||
type:
|
||||
_class: qwen2
|
||||
is_critic: true
|
||||
path: '/storage/models/DeepSeek-R1-Distill-Qwen-1.5B'
|
||||
init_critic_from_actor: true
|
||||
ref:
|
||||
type:
|
||||
_class: qwen2
|
||||
path: '/storage/models/DeepSeek-R1-Distill-Qwen-1.5B'
|
||||
actor_train:
|
||||
mb_spec:
|
||||
max_tokens_per_mb: 19456
|
||||
critic_train:
|
||||
mb_spec:
|
||||
max_tokens_per_mb: 19456
|
||||
actor_gen:
|
||||
mb_spec:
|
||||
max_tokens_per_mb: 19456
|
||||
critic_inf:
|
||||
mb_spec:
|
||||
max_tokens_per_mb: 19456
|
||||
actor_inf:
|
||||
mb_spec:
|
||||
max_tokens_per_mb: 19456
|
||||
ref_inf:
|
||||
mb_spec:
|
||||
max_tokens_per_mb: 19456
|
||||
dataset:
|
||||
path: '/storage/datasets/prompts_for_r1_distilled_0319.jsonl'
|
||||
max_prompt_len: 2048
|
||||
train_bs_n_seqs: 1024
|
||||
ppo:
|
||||
gen:
|
||||
max_new_tokens: 16384
|
||||
min_new_tokens: 0
|
||||
top_p: 1.0
|
||||
top_k: 1000000
|
||||
temperature: 1.0
|
||||
ppo_n_minibatches: 4
|
||||
kl_ctl: 0.001
|
||||
discount: 1.0
|
||||
value_eps_clip: 0.2
|
||||
disable_value: true
|
||||
reward_output_scaling: 5.0
|
||||
reward_output_bias: 0.0
|
||||
adv_norm: true
|
||||
value_norm: true
|
||||
group_size: 8
|
||||
group_adv_norm: false
|
||||
external_configs:
|
||||
cluster_config:
|
||||
fileroot: "/storage/ray/experiments"
|
||||
envs:
|
||||
REAL_GPU_MEMORY_KILL_THRESHOLD: "1"
|
|
@ -28,6 +28,9 @@ from .api.quickstart.model import (
|
|||
OptimizerConfig,
|
||||
ParallelismConfig,
|
||||
)
|
||||
|
||||
# Initialize preset config before all submodules.
|
||||
from .base import prologue
|
||||
from .experiments.common.common import CommonExperimentConfig, ExperimentSaveEvalControl
|
||||
from .experiments.common.ppo_math_exp import PPOHyperparameters, PPOMATHConfig
|
||||
from .experiments.common.sft_exp import SFTConfig
|
||||
|
|
|
@ -5,15 +5,24 @@
|
|||
import argparse
|
||||
import datetime
|
||||
import getpass
|
||||
import os
|
||||
import pathlib
|
||||
import re
|
||||
import sys
|
||||
|
||||
import hydra
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
from realhf.api.quickstart.entrypoint import QUICKSTART_FN
|
||||
from realhf.base.cluster import spec as cluster_spec
|
||||
from realhf.base.importing import import_module
|
||||
from realhf.base.prologue import (
|
||||
PROLOGUE_EXTERNAL_CONFIG_NAME,
|
||||
PROLOGUE_FLAG_NAME,
|
||||
PROLOGUE_FLAG_VAR_NAME,
|
||||
get_experiment_name,
|
||||
get_trial_name,
|
||||
)
|
||||
|
||||
# NOTE: Register all implemented experiments inside ReaL.
|
||||
import_module(
|
||||
|
@ -34,46 +43,65 @@ def main():
|
|||
action="store_true",
|
||||
help="Show all legal CLI arguments for this experiment.",
|
||||
)
|
||||
subparser.add_argument(
|
||||
PROLOGUE_FLAG_NAME,
|
||||
type=str,
|
||||
help="Set config (*.yaml) for this experiment.",
|
||||
)
|
||||
subparser.set_defaults(func=v)
|
||||
args = parser.parse_known_args()[0]
|
||||
if args.show_args:
|
||||
args = vars(parser.parse_known_args()[0])
|
||||
if args["show_args"]:
|
||||
sys.argv = [sys.argv[0], "--help"]
|
||||
QUICKSTART_FN[args.cmd]()
|
||||
QUICKSTART_FN[args["cmd"]]()
|
||||
return
|
||||
|
||||
launch_hydra_task(args.cmd, QUICKSTART_FN[args.cmd])
|
||||
experiment_name = ""
|
||||
trial_name = ""
|
||||
|
||||
if args[PROLOGUE_FLAG_VAR_NAME]:
|
||||
config_dir, experiment_name, trial_name = prepare_hydra_config(
|
||||
args["cmd"], args[PROLOGUE_FLAG_VAR_NAME]
|
||||
)
|
||||
sys.argv.remove(PROLOGUE_FLAG_NAME)
|
||||
sys.argv.remove(args[PROLOGUE_FLAG_VAR_NAME])
|
||||
sys.argv += [f"--config-path", f"{config_dir}"]
|
||||
else:
|
||||
experiment_name = get_experiment_name()
|
||||
trial_name = get_trial_name()
|
||||
|
||||
launch_hydra_task(
|
||||
args["cmd"], experiment_name, trial_name, QUICKSTART_FN[args["cmd"]]
|
||||
)
|
||||
|
||||
|
||||
def launch_hydra_task(name: str, func: hydra.TaskFunction):
|
||||
def prepare_hydra_config(name: str, prologue_path: str):
    """Materialize the prologue YAML as a Hydra config file.

    Loads the YAML at ``prologue_path``, resolves the experiment and trial
    names from it (command-line overrides are handled by
    ``get_experiment_name`` / ``get_trial_name``), drops the
    ``external_configs`` section and writes the remainder to
    ``<fileroot>/configs/<user>/<experiment>/<trial>/<name>.yaml``.

    Args:
        name: Base name (without extension) of the YAML file to write.
        prologue_path: Path of the user-supplied ``--config`` YAML file.

    Returns:
        A ``(config_dir, experiment_name, trial_name)`` tuple, where
        ``config_dir`` is the directory the YAML file was written to.
    """
    config = OmegaConf.load(prologue_path)
    experiment_name = get_experiment_name(config.get("experiment_name"))
    trial_name = get_trial_name(config.get("trial_name"))
    config_dir = f"{cluster_spec.fileroot}/configs/{getpass.getuser()}/{experiment_name}/{trial_name}"
    # Do not rely on an earlier step having created the directory.
    os.makedirs(config_dir, exist_ok=True)

    # The section is optional in the YAML, so a missing key must not raise.
    config.pop(PROLOGUE_EXTERNAL_CONFIG_NAME, None)
    with open(f"{config_dir}/{name}.yaml", "w") as f:
        f.write(OmegaConf.to_yaml(config))

    return (config_dir, experiment_name, trial_name)
|
||||
|
||||
|
||||
def launch_hydra_task(
|
||||
name: str, experiment_name: str, trial_name: str, func: hydra.TaskFunction
|
||||
):
|
||||
# Disable hydra logging.
|
||||
if not any("hydra/job_logging=disabled" in x for x in sys.argv):
|
||||
sys.argv += ["hydra/job_logging=disabled"]
|
||||
|
||||
if any("experiment_name=" in x for x in sys.argv):
|
||||
experiment_name = next(x for x in sys.argv if "experiment_name=" in x).split(
|
||||
"="
|
||||
)[1]
|
||||
if "_" in experiment_name:
|
||||
raise RuntimeError("experiment_name should not contain `_`.")
|
||||
else:
|
||||
experiment_name = f"quickstart-{name}"
|
||||
print(f"Experiment name not manually set. Default to {experiment_name}.")
|
||||
sys.argv += [f"experiment_name={experiment_name}"]
|
||||
|
||||
if (
|
||||
"--multirun" in sys.argv
|
||||
or "hydra.mode=MULTIRUN" in sys.argv
|
||||
or "-m" in sys.argv
|
||||
):
|
||||
raise NotImplementedError("Hydra multi-run is not supported.")
|
||||
# non-multirun mode, add trial_name and hydra run dir
|
||||
if any("trial_name=" in x for x in sys.argv):
|
||||
trial_name = next(x for x in sys.argv if "trial_name=" in x).split("=")[1]
|
||||
else:
|
||||
trial_name = f"run{datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
|
||||
sys.argv += [f"trial_name={trial_name}"]
|
||||
if "_" in trial_name:
|
||||
raise RuntimeError("trial_name should not contain `_`.")
|
||||
|
||||
# non-multirun mode, add hydra run dir
|
||||
sys.argv += [
|
||||
f"hydra.run.dir={cluster_spec.fileroot}/logs/{getpass.getuser()}/"
|
||||
f"{experiment_name}/{trial_name}/hydra-outputs/"
|
||||
|
|
|
@ -0,0 +1,87 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License").
|
||||
|
||||
import argparse
|
||||
import datetime
|
||||
import getpass
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
# Command-line flag that carries the path of the experiment YAML file.
PROLOGUE_FLAG_NAME = "--config"
|
||||
# Key under which argparse stores the value of the flag above in vars(args).
PROLOGUE_FLAG_VAR_NAME = "config"
|
||||
# Top-level key of the YAML file whose section global_init() reads.
PROLOGUE_EXTERNAL_CONFIG_NAME = "external_configs"
|
||||
|
||||
|
||||
def global_init():
    """Apply the ``external_configs`` section of the ``--config`` YAML file.

    Does nothing when ``--config`` is absent from the command line or when
    the YAML file has no ``external_configs`` section. Otherwise:

    * every entry of ``external_configs.envs`` is exported as an environment
      variable unless a variable of that name is already set;
    * if ``CLUSTER_SPEC_PATH`` is unset or empty and
      ``external_configs.cluster_config.fileroot`` is non-empty, the cluster
      config is dumped to
      ``<fileroot>/configs/<user>/<experiment>/<trial>/cluster_config.json``
      and ``CLUSTER_SPEC_PATH`` is pointed at that file.
    """
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument(PROLOGUE_FLAG_NAME)
    # parse_known_args: every other command-line argument is left alone.
    args = vars(parser.parse_known_args()[0])
    prologue_path = args[PROLOGUE_FLAG_VAR_NAME]
    if prologue_path is None:
        return

    config = OmegaConf.load(prologue_path)
    external_configs = config.get(PROLOGUE_EXTERNAL_CONFIG_NAME)
    if external_configs is None:
        return

    # Add external envs; values already present in the environment win.
    envs = external_configs.get("envs")
    if envs is not None:
        for key, value in envs.items():
            if key not in os.environ:
                # YAML scalars may be int/bool/float; os.environ accepts
                # only strings.
                os.environ[key] = str(value)

    # Resolve the config path for the cluster spec; an explicitly provided
    # CLUSTER_SPEC_PATH takes precedence over the YAML section.
    if os.environ.get("CLUSTER_SPEC_PATH", "") != "":
        return
    cluster_config = external_configs.get("cluster_config")
    if cluster_config is None:
        return
    fileroot = cluster_config.get("fileroot")
    if fileroot is None or fileroot == "":
        return

    experiment_name = get_experiment_name(config.get("experiment_name"))
    trial_name = get_trial_name(config.get("trial_name"))
    config_dir = f"{fileroot}/configs/{getpass.getuser()}/{experiment_name}/{trial_name}"
    os.makedirs(config_dir, exist_ok=True)

    cluster_spec_path = f"{config_dir}/cluster_config.json"
    cluster_spec = OmegaConf.to_container(cluster_config)
    # Fill in identifying fields from the launch mode when not given.
    if "cluster_type" not in cluster_spec:
        cluster_spec["cluster_type"] = config.mode
    if "cluster_name" not in cluster_spec:
        cluster_spec["cluster_name"] = f"{config.mode}_cluster"
    with open(cluster_spec_path, "w") as f:
        json.dump(cluster_spec, f)
    # NOTE(review): exported only when a spec file was actually written --
    # confirm no consumer expects the variable to be set to "" otherwise.
    os.environ["CLUSTER_SPEC_PATH"] = cluster_spec_path
|
||||
|
||||
|
||||
def get_experiment_name(default_name: str = ""):
    """Resolve the experiment name for this run.

    An ``experiment_name=<value>`` override on the command line (optionally
    prefixed with Hydra's ``+`` / ``++``) takes precedence. Otherwise
    ``default_name`` is used. When the result is empty or ``None`` the name
    falls back to ``"quickstart-experiment"``.

    Args:
        default_name: Name to use when the command line has no override;
            ``None`` is treated like an empty string.

    Returns:
        The resolved experiment name.

    Raises:
        RuntimeError: If the resolved name contains an underscore.
    """
    # A key missing from the YAML config arrives here as None.
    experiment_name = "" if default_name is None else str(default_name)
    for arg in sys.argv:
        # Compare the key exactly so that arguments merely containing the
        # text "experiment_name=" do not match; partition() keeps any "="
        # inside the value intact.
        key, sep, value = arg.partition("=")
        if sep and key.lstrip("+") == "experiment_name":
            experiment_name = value
            break

    if not experiment_name:
        experiment_name = "quickstart-experiment"

    if "_" in experiment_name:
        raise RuntimeError("experiment_name should not contain `_`.")
    return experiment_name
|
||||
|
||||
|
||||
def get_trial_name(default_name: str = ""):
    """Resolve the trial name for this run.

    A ``trial_name=<value>`` override on the command line (optionally
    prefixed with Hydra's ``+`` / ``++``) takes precedence. Otherwise
    ``default_name`` is used. When the result is empty or ``None`` the name
    falls back to ``run<YYYY-mm-dd-HH-MM-SS>`` built from the current local
    time.

    Args:
        default_name: Name to use when the command line has no override;
            ``None`` is treated like an empty string.

    Returns:
        The resolved trial name.

    Raises:
        RuntimeError: If the resolved name contains an underscore.
    """
    # A key missing from the YAML config arrives here as None.
    trial_name = "" if default_name is None else str(default_name)
    for arg in sys.argv:
        # Exact key comparison; partition() keeps any "=" inside the value.
        key, sep, value = arg.partition("=")
        if sep and key.lstrip("+") == "trial_name":
            trial_name = value
            break

    if not trial_name:
        trial_name = f"run{datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"

    if "_" in trial_name:
        raise RuntimeError("trial_name should not contain `_`.")
    return trial_name
|
||||
|
||||
|
||||
# Executed when this module is imported, so the environment variables that
# global_init() sets are in place before the rest of the import continues.
global_init()
|
Loading…
Reference in New Issue