mirror of https://github.com/inclusionAI/AReaL
format
This commit is contained in:
parent
86d08db879
commit
68d8e860a3
|
@ -2,9 +2,6 @@
|
|||
# Copyright 2024 Wei Fu & Zhiyu Mei
|
||||
# Licensed under the Apache License, Version 2.0 (the "License").
|
||||
|
||||
# Initialize preset config before all submodules.
|
||||
from .base import prologue
|
||||
|
||||
# Re-import these classes for clear documentation,
|
||||
# otherwise the name will have a long prefix like
|
||||
# realhf.api.quickstart.model.ModelTrainEvalConfig.
|
||||
|
@ -31,6 +28,9 @@ from .api.quickstart.model import (
|
|||
OptimizerConfig,
|
||||
ParallelismConfig,
|
||||
)
|
||||
|
||||
# Initialize preset config before all submodules.
|
||||
from .base import prologue
|
||||
from .experiments.common.common import CommonExperimentConfig, ExperimentSaveEvalControl
|
||||
from .experiments.common.ppo_math_exp import PPOHyperparameters, PPOMATHConfig
|
||||
from .experiments.common.sft_exp import SFTConfig
|
||||
|
|
|
@ -5,18 +5,24 @@
|
|||
import argparse
|
||||
import datetime
|
||||
import getpass
|
||||
import os
|
||||
import pathlib
|
||||
import re
|
||||
import sys
|
||||
import os
|
||||
|
||||
import hydra
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
from realhf.base.prologue import PROLOGUE_FLAG_NAME, PROLOGUE_FLAG_VAR_NAME, PROLOGUE_EXTERNAL_CONFIG_NAME, get_experiment_name, get_trial_name
|
||||
from realhf.api.quickstart.entrypoint import QUICKSTART_FN
|
||||
from realhf.base.cluster import spec as cluster_spec
|
||||
from realhf.base.importing import import_module
|
||||
from realhf.base.prologue import (
|
||||
PROLOGUE_EXTERNAL_CONFIG_NAME,
|
||||
PROLOGUE_FLAG_NAME,
|
||||
PROLOGUE_FLAG_VAR_NAME,
|
||||
get_experiment_name,
|
||||
get_trial_name,
|
||||
)
|
||||
|
||||
# NOTE: Register all implemented experiments inside ReaL.
|
||||
import_module(
|
||||
|
@ -53,7 +59,9 @@ def main():
|
|||
trial_name = ""
|
||||
|
||||
if args[PROLOGUE_FLAG_VAR_NAME]:
|
||||
config_dir, experiment_name, trial_name = prepare_hydra_config(args["cmd"], args[PROLOGUE_FLAG_VAR_NAME])
|
||||
config_dir, experiment_name, trial_name = prepare_hydra_config(
|
||||
args["cmd"], args[PROLOGUE_FLAG_VAR_NAME]
|
||||
)
|
||||
sys.argv.remove(PROLOGUE_FLAG_NAME)
|
||||
sys.argv.remove(args[PROLOGUE_FLAG_VAR_NAME])
|
||||
sys.argv += [f"--config-path", f"{config_dir}"]
|
||||
|
@ -61,7 +69,9 @@ def main():
|
|||
experiment_name = get_experiment_name()
|
||||
trial_name = get_trial_name()
|
||||
|
||||
launch_hydra_task(args["cmd"], experiment_name, trial_name, QUICKSTART_FN[args["cmd"]])
|
||||
launch_hydra_task(
|
||||
args["cmd"], experiment_name, trial_name, QUICKSTART_FN[args["cmd"]]
|
||||
)
|
||||
|
||||
|
||||
def prepare_hydra_config(name: str, prologue_path: str):
|
||||
|
@ -76,7 +86,10 @@ def prepare_hydra_config(name: str, prologue_path: str):
|
|||
|
||||
return (config_dir, experiment_name, trial_name)
|
||||
|
||||
def launch_hydra_task(name: str, experiment_name: str, trial_name: str, func: hydra.TaskFunction):
|
||||
|
||||
def launch_hydra_task(
|
||||
name: str, experiment_name: str, trial_name: str, func: hydra.TaskFunction
|
||||
):
|
||||
# Disable hydra logging.
|
||||
if not any("hydra/job_logging=disabled" in x for x in sys.argv):
|
||||
sys.argv += ["hydra/job_logging=disabled"]
|
||||
|
|
|
@ -2,17 +2,19 @@
|
|||
# Licensed under the Apache License, Version 2.0 (the "License").
|
||||
|
||||
import argparse
|
||||
import getpass
|
||||
import sys
|
||||
import datetime
|
||||
import os
|
||||
import getpass
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
PROLOGUE_FLAG_NAME = "--config"
|
||||
PROLOGUE_FLAG_VAR_NAME = "config"
|
||||
PROLOGUE_EXTERNAL_CONFIG_NAME = "external_configs"
|
||||
|
||||
|
||||
def global_init():
|
||||
parser = argparse.ArgumentParser(add_help=False)
|
||||
parser.add_argument(PROLOGUE_FLAG_NAME)
|
||||
|
@ -53,6 +55,7 @@ def global_init():
|
|||
json.dump(cluster_spec, f)
|
||||
os.environ["CLUSTER_SPEC_PATH"] = cluster_spec_path
|
||||
|
||||
|
||||
def get_experiment_name(default_name: str = ""):
|
||||
if any("experiment_name=" in x for x in sys.argv):
|
||||
experiment_name = next(x for x in sys.argv if "experiment_name=" in x).split(
|
||||
|
@ -60,13 +63,14 @@ def get_experiment_name(default_name: str = ""):
|
|||
)[1]
|
||||
else:
|
||||
experiment_name = default_name
|
||||
if experiment_name == "":
|
||||
if experiment_name == "":
|
||||
experiment_name = f"quickstart-experiment"
|
||||
|
||||
if "_" in experiment_name:
|
||||
raise RuntimeError("experiment_name should not contain `_`.")
|
||||
return experiment_name
|
||||
|
||||
|
||||
def get_trial_name(default_name: str = ""):
|
||||
if any("trial_name=" in x for x in sys.argv):
|
||||
trial_name = next(x for x in sys.argv if "trial_name=" in x).split("=")[1]
|
||||
|
|
Loading…
Reference in New Issue