This commit is contained in:
bowei.fw 2025-03-28 11:24:37 +08:00
parent 86d08db879
commit 68d8e860a3
3 changed files with 29 additions and 12 deletions

View File

@ -2,9 +2,6 @@
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
# Initialize preset config before all submodules.
from .base import prologue
# Re-import these classes for clear documentation,
# otherwise the name will have a long prefix like
# realhf.api.quickstart.model.ModelTrainEvalConfig.
@ -31,6 +28,9 @@ from .api.quickstart.model import (
OptimizerConfig,
ParallelismConfig,
)
# Initialize preset config before all submodules.
from .base import prologue
from .experiments.common.common import CommonExperimentConfig, ExperimentSaveEvalControl
from .experiments.common.ppo_math_exp import PPOHyperparameters, PPOMATHConfig
from .experiments.common.sft_exp import SFTConfig

View File

@ -5,18 +5,24 @@
import argparse
import datetime
import getpass
import os
import pathlib
import re
import sys
import os
import hydra
from omegaconf import DictConfig, OmegaConf
from realhf.base.prologue import PROLOGUE_FLAG_NAME, PROLOGUE_FLAG_VAR_NAME, PROLOGUE_EXTERNAL_CONFIG_NAME, get_experiment_name, get_trial_name
from realhf.api.quickstart.entrypoint import QUICKSTART_FN
from realhf.base.cluster import spec as cluster_spec
from realhf.base.importing import import_module
from realhf.base.prologue import (
PROLOGUE_EXTERNAL_CONFIG_NAME,
PROLOGUE_FLAG_NAME,
PROLOGUE_FLAG_VAR_NAME,
get_experiment_name,
get_trial_name,
)
# NOTE: Register all implemented experiments inside ReaL.
import_module(
@ -53,7 +59,9 @@ def main():
trial_name = ""
if args[PROLOGUE_FLAG_VAR_NAME]:
config_dir, experiment_name, trial_name = prepare_hydra_config(args["cmd"], args[PROLOGUE_FLAG_VAR_NAME])
config_dir, experiment_name, trial_name = prepare_hydra_config(
args["cmd"], args[PROLOGUE_FLAG_VAR_NAME]
)
sys.argv.remove(PROLOGUE_FLAG_NAME)
sys.argv.remove(args[PROLOGUE_FLAG_VAR_NAME])
sys.argv += [f"--config-path", f"{config_dir}"]
@ -61,7 +69,9 @@ def main():
experiment_name = get_experiment_name()
trial_name = get_trial_name()
launch_hydra_task(args["cmd"], experiment_name, trial_name, QUICKSTART_FN[args["cmd"]])
launch_hydra_task(
args["cmd"], experiment_name, trial_name, QUICKSTART_FN[args["cmd"]]
)
def prepare_hydra_config(name: str, prologue_path: str):
@ -76,7 +86,10 @@ def prepare_hydra_config(name: str, prologue_path: str):
return (config_dir, experiment_name, trial_name)
def launch_hydra_task(name: str, experiment_name: str, trial_name: str, func: hydra.TaskFunction):
def launch_hydra_task(
name: str, experiment_name: str, trial_name: str, func: hydra.TaskFunction
):
# Disable hydra logging.
if not any("hydra/job_logging=disabled" in x for x in sys.argv):
sys.argv += ["hydra/job_logging=disabled"]

View File

@ -2,17 +2,19 @@
# Licensed under the Apache License, Version 2.0 (the "License").
import argparse
import getpass
import sys
import datetime
import os
import getpass
import json
import os
import sys
from omegaconf import DictConfig, OmegaConf
PROLOGUE_FLAG_NAME = "--config"
PROLOGUE_FLAG_VAR_NAME = "config"
PROLOGUE_EXTERNAL_CONFIG_NAME = "external_configs"
def global_init():
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument(PROLOGUE_FLAG_NAME)
@ -53,6 +55,7 @@ def global_init():
json.dump(cluster_spec, f)
os.environ["CLUSTER_SPEC_PATH"] = cluster_spec_path
def get_experiment_name(default_name: str = ""):
if any("experiment_name=" in x for x in sys.argv):
experiment_name = next(x for x in sys.argv if "experiment_name=" in x).split(
@ -60,13 +63,14 @@ def get_experiment_name(default_name: str = ""):
)[1]
else:
experiment_name = default_name
if experiment_name == "":
if experiment_name == "":
experiment_name = f"quickstart-experiment"
if "_" in experiment_name:
raise RuntimeError("experiment_name should not contain `_`.")
return experiment_name
def get_trial_name(default_name: str = ""):
if any("trial_name=" in x for x in sys.argv):
trial_name = next(x for x in sys.argv if "trial_name=" in x).split("=")[1]