simplify startup command

This commit is contained in:
kira.gw 2025-03-27 20:44:44 +08:00
parent 9f77f96580
commit 62b9bf3f44
4 changed files with 211 additions and 24 deletions

View File

@ -0,0 +1,86 @@
experiment_name: areal-1.5B-distill-gpus-8
trial_name: 1024x8
mode: ray
wandb:
mode: disabled
recover_mode: auto
recover_retries: 10
allocation_mode: 'sglang.d4p1m1+d4p1m1'
n_nodes: 1
n_gpus_per_node: 8
cache_clear_freq: 1
exp_ctrl:
total_train_epochs: 10
save_freq_epochs: 1
ckpt_freq_secs: 600
torch_cache_mysophobia: true
actor:
type:
_class: qwen2
path: '/storage/models/DeepSeek-R1-Distill-Qwen-1.5B'
optimizer:
lr: 1.0e-05
lr_scheduler_type: constant
initial_loss_scale: 262144.0
loss_scale_window: 5.0
hysteresis: 2
sglang:
disable_radix_cache: true
context_length: 18432
mem_fraction_static: 0.8
max_running_requests: 128
critic:
type:
_class: qwen2
is_critic: true
path: '/storage/models/DeepSeek-R1-Distill-Qwen-1.5B'
init_critic_from_actor: true
ref:
type:
_class: qwen2
path: '/storage/models/DeepSeek-R1-Distill-Qwen-1.5B'
actor_train:
mb_spec:
max_tokens_per_mb: 19456
critic_train:
mb_spec:
max_tokens_per_mb: 19456
actor_gen:
mb_spec:
max_tokens_per_mb: 19456
critic_inf:
mb_spec:
max_tokens_per_mb: 19456
actor_inf:
mb_spec:
max_tokens_per_mb: 19456
ref_inf:
mb_spec:
max_tokens_per_mb: 19456
dataset:
path: '/storage/datasets/prompts_for_r1_distilled_0319.jsonl'
max_prompt_len: 2048
train_bs_n_seqs: 1024
ppo:
gen:
max_new_tokens: 16384
min_new_tokens: 0
top_p: 1.0
top_k: 1000000
temperature: 1.0
ppo_n_minibatches: 4
kl_ctl: 0.001
discount: 1.0
value_eps_clip: 0.2
disable_value: true
reward_output_scaling: 5.0
reward_output_bias: 0.0
adv_norm: true
value_norm: true
group_size: 8
group_adv_norm: false
external_configs:
cluster_config:
fileroot: "/storage/ray/experiments"
envs:
REAL_GPU_MEMORY_KILL_THRESHOLD: "1"

View File

@ -2,6 +2,9 @@
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
# Initialize preset config before all submodules.
from .base import prologue
# Re-import these classes for clear documentation,
# otherwise the name will have a long prefix like
# realhf.api.quickstart.model.ModelTrainEvalConfig.

View File

@ -8,9 +8,12 @@ import getpass
import pathlib
import re
import sys
import os
import hydra
from omegaconf import DictConfig, OmegaConf
from realhf.base.preset import PRESET_FLAG_NAME, PRESET_FLAG_VAR_NAME, PRESET_EXTERNAL_CONFIG_NAME, get_experiment_name, get_trial_name
from realhf.api.quickstart.entrypoint import QUICKSTART_FN
from realhf.base.cluster import spec as cluster_spec
from realhf.base.importing import import_module
@ -34,46 +37,58 @@ def main():
action="store_true",
help="Show all legal CLI arguments for this experiment.",
)
subparser.add_argument(
PRESET_FLAG_NAME,
type=str,
help="Set preset config (*.yaml) for this experiment.",
)
subparser.set_defaults(func=v)
args = parser.parse_known_args()[0]
if args.show_args:
args = vars(parser.parse_known_args()[0])
if args["show_args"]:
sys.argv = [sys.argv[0], "--help"]
QUICKSTART_FN[args.cmd]()
QUICKSTART_FN[args["cmd"]]()
return
launch_hydra_task(args.cmd, QUICKSTART_FN[args.cmd])
experiment_name = ""
trial_name = ""
if args[PRESET_FLAG_VAR_NAME]:
config_dir, experiment_name, trial_name = prepare_hydra_config(args["cmd"], args[PRESET_FLAG_VAR_NAME])
sys.argv.remove(PRESET_FLAG_NAME)
sys.argv.remove(args[PRESET_FLAG_VAR_NAME])
sys.argv += [f"--config-path", f"{config_dir}"]
else:
experiment_name = get_experiment_name()
trial_name = get_trial_name()
launch_hydra_task(args["cmd"], experiment_name, trial_name, QUICKSTART_FN[args["cmd"]])
def launch_hydra_task(name: str, func: hydra.TaskFunction):
def prepare_hydra_config(name: str, preset_path: str):
    """Write a preset YAML out as a config file and report where it went.

    The preset is loaded, its ``external_configs`` section (if any) is
    removed, and the remainder is dumped to ``<config_dir>/<name>.yaml``.
    The caller passes ``config_dir`` on via ``--config-path``.

    Args:
        name: Quickstart command name; used as the YAML file stem.
        preset_path: Path of the user-supplied preset ``*.yaml``.

    Returns:
        A ``(config_dir, experiment_name, trial_name)`` tuple.
    """
    config = OmegaConf.load(preset_path)
    experiment_name = get_experiment_name(config.get("experiment_name"))
    trial_name = get_trial_name(config.get("trial_name"))
    config_dir = (
        f"{cluster_spec.fileroot}/configs/{getpass.getuser()}/"
        f"{experiment_name}/{trial_name}"
    )
    # The directory may not exist yet (e.g. the preset carries no cluster
    # config), and open(..., "w") does not create parent directories.
    os.makedirs(config_dir, exist_ok=True)
    # The external section is optional; pop with a default so that a preset
    # without it does not raise.
    config.pop(PRESET_EXTERNAL_CONFIG_NAME, None)
    with open(f"{config_dir}/{name}.yaml", "w") as f:
        f.write(OmegaConf.to_yaml(config))
    return (config_dir, experiment_name, trial_name)
def launch_hydra_task(name: str, experiment_name: str, trial_name: str, func: hydra.TaskFunction):
# Disable hydra logging.
if not any("hydra/job_logging=disabled" in x for x in sys.argv):
sys.argv += ["hydra/job_logging=disabled"]
if any("experiment_name=" in x for x in sys.argv):
experiment_name = next(x for x in sys.argv if "experiment_name=" in x).split(
"="
)[1]
if "_" in experiment_name:
raise RuntimeError("experiment_name should not contain `_`.")
else:
experiment_name = f"quickstart-{name}"
print(f"Experiment name not manually set. Default to {experiment_name}.")
sys.argv += [f"experiment_name={experiment_name}"]
if (
"--multirun" in sys.argv
or "hydra.mode=MULTIRUN" in sys.argv
or "-m" in sys.argv
):
raise NotImplementedError("Hydra multi-run is not supported.")
# non-multirun mode, add trial_name and hydra run dir
if any("trial_name=" in x for x in sys.argv):
trial_name = next(x for x in sys.argv if "trial_name=" in x).split("=")[1]
else:
trial_name = f"run{datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
sys.argv += [f"trial_name={trial_name}"]
if "_" in trial_name:
raise RuntimeError("trial_name should not contain `_`.")
# non-multirun mode, add hydra run dir
sys.argv += [
f"hydra.run.dir={cluster_spec.fileroot}/logs/{getpass.getuser()}/"
f"{experiment_name}/{trial_name}/hydra-outputs/"

83
realhf/base/prologue.py Normal file
View File

@ -0,0 +1,83 @@
# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0 (the "License").
import argparse
import getpass
import sys
import datetime
import os
import json
from omegaconf import DictConfig, OmegaConf
# Command-line flag that carries the path of the preset YAML file.
PRESET_FLAG_NAME = "--config"
# Key under which argparse stores the value of the flag above.
PRESET_FLAG_VAR_NAME = "config"
# Top-level key of the preset file holding the external settings
# (read by global_init below).
PRESET_EXTERNAL_CONFIG_NAME = "external_configs"
def global_init():
    """Apply the ``external_configs`` section of the preset file, if any.

    Looks for the preset flag in ``sys.argv`` (all other arguments are
    ignored), loads the referenced YAML and then:

    1. Exports every entry of ``external_configs.envs`` to ``os.environ``
       unless the variable is already set.
    2. If ``CLUSTER_SPEC_PATH`` is unset or empty and
       ``external_configs.cluster_config.fileroot`` is given, dumps the
       cluster config as JSON under
       ``<fileroot>/configs/<user>/<experiment>/<trial>/cluster_config.json``
       and points ``CLUSTER_SPEC_PATH`` at it.

    Returns silently when no preset flag is given or the preset has no
    ``external_configs`` section.
    """
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument(PRESET_FLAG_NAME)
    # parse_known_args: only the preset flag matters here, leave the rest.
    args = vars(parser.parse_known_args()[0])
    preset_path = args[PRESET_FLAG_VAR_NAME]
    if preset_path is None:
        return

    config = OmegaConf.load(preset_path)
    external_configs = config.get(PRESET_EXTERNAL_CONFIG_NAME)
    if external_configs is None:
        return

    # Add external envs. Every sub-key is optional, so use .get() rather than
    # attribute access, which raises on a missing key in current OmegaConf.
    envs = external_configs.get("envs")
    if envs is not None:
        for key, value in envs.items():
            if value is None or key in os.environ:
                continue
            # os.environ only accepts strings; YAML scalars may be int/bool.
            os.environ[key] = str(value)

    # Resolve config path for cluster spec; an explicit env setting wins.
    if os.environ.get("CLUSTER_SPEC_PATH", "") != "":
        return
    cluster_config = external_configs.get("cluster_config")
    if cluster_config is None:
        return
    fileroot = cluster_config.get("fileroot")
    if fileroot is None or fileroot == "":
        return

    experiment_name = get_experiment_name(config.get("experiment_name"))
    trial_name = get_trial_name(config.get("trial_name"))
    config_dir = (
        f"{fileroot}/configs/{getpass.getuser()}/{experiment_name}/{trial_name}"
    )
    os.makedirs(config_dir, exist_ok=True)

    cluster_spec_path = f"{config_dir}/cluster_config.json"
    cluster_spec = OmegaConf.to_container(cluster_config)
    # Fill in cluster identity from the launch mode when the preset omits it.
    if "cluster_type" not in cluster_spec:
        cluster_spec["cluster_type"] = config.mode
    if "cluster_name" not in cluster_spec:
        cluster_spec["cluster_name"] = f"{config.mode}_cluster"
    with open(cluster_spec_path, "w") as f:
        json.dump(cluster_spec, f)
    os.environ["CLUSTER_SPEC_PATH"] = cluster_spec_path
def get_experiment_name(default_name: str = ""):
    """Resolve the experiment name for this run.

    Precedence: the first ``sys.argv`` entry containing
    ``experiment_name=``, then ``default_name``, then the fixed fallback
    ``"quickstart-experiment"``.

    Args:
        default_name: Name to use when the command line does not set one.
            ``None`` or an empty string selects the fallback.

    Returns:
        The resolved experiment name.

    Raises:
        RuntimeError: If the resolved name contains an underscore.
    """
    override = next((x for x in sys.argv if "experiment_name=" in x), None)
    if override is not None:
        # Split once only, so a value that itself contains "=" stays intact.
        experiment_name = override.split("=", 1)[1]
    else:
        # Callers pass config.get(...), which is None for a missing key.
        experiment_name = "" if default_name is None else str(default_name)
    if not experiment_name:
        experiment_name = "quickstart-experiment"
    if "_" in experiment_name:
        raise RuntimeError("experiment_name should not contain `_`.")
    return experiment_name
def get_trial_name(default_name: str = ""):
    """Resolve the trial name for this run.

    Precedence: the first ``sys.argv`` entry containing ``trial_name=``,
    then ``default_name``, then a generated ``run<timestamp>`` name built
    from the current local time.

    Args:
        default_name: Name to use when the command line does not set one.
            ``None`` or an empty string selects the generated name.

    Returns:
        The resolved trial name.

    Raises:
        RuntimeError: If the resolved name contains an underscore.
    """
    override = next((x for x in sys.argv if "trial_name=" in x), None)
    if override is not None:
        # Split once only, so a value that itself contains "=" stays intact.
        trial_name = override.split("=", 1)[1]
    else:
        # Callers pass config.get(...), which is None for a missing key.
        trial_name = "" if default_name is None else str(default_name)
    if not trial_name:
        trial_name = f"run{datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
    if "_" in trial_name:
        raise RuntimeError("trial_name should not contain `_`.")
    return trial_name
# Executed at import time: merely importing this module applies the preset's
# external settings (environment variables, CLUSTER_SPEC_PATH).
global_init()