mirror of https://github.com/inclusionAI/AReaL
104 lines
3.3 KiB
Python
104 lines
3.3 KiB
Python
# Copyright 2025 Ant Group Inc.
|
|
# Copyright 2024 Wei Fu & Zhiyu Mei
|
|
# Licensed under the Apache License, Version 2.0 (the "License").
|
|
|
|
import argparse
|
|
import datetime
|
|
import getpass
|
|
import pathlib
|
|
import re
|
|
import sys
|
|
import os
|
|
|
|
import hydra
|
|
from omegaconf import DictConfig, OmegaConf
|
|
|
|
from realhf.base.prologue import PROLOGUE_FLAG_NAME, PROLOGUE_FLAG_VAR_NAME, PROLOGUE_EXTERNAL_CONFIG_NAME, get_experiment_name, get_trial_name
|
|
from realhf.api.quickstart.entrypoint import QUICKSTART_FN
|
|
from realhf.base.cluster import spec as cluster_spec
|
|
from realhf.base.importing import import_module
|
|
|
|
# NOTE: Register all implemented experiments inside ReaL.
# The helper receives the "experiments/common" directory located two levels
# above this file (after resolving symlinks) and a compiled pattern matching
# file names that end in "_exp.py". Presumably it imports every matching
# module so that import-time registration code runs -- confirm against
# realhf.base.importing.import_module.
import_module(
    str(pathlib.Path(__file__).resolve().parent.parent / "experiments" / "common"),
    re.compile(r".*_exp\.py$"),
)
|
|
import realhf.experiments.benchmark.profile_exp
|
|
|
|
|
|
def _strip_prologue_flag(argv, flag):
    """Return a copy of ``argv`` without ``flag`` and the value attached to it.

    Both spellings accepted by argparse are handled: ``flag value`` (two
    tokens) and ``flag=value`` (one token). The value is dropped by position,
    so an unrelated token that merely equals the value is left alone.
    """
    cleaned = []
    skip_value = False
    inline_prefix = f"{flag}="
    for token in argv:
        if skip_value:
            # This token is the value that followed a bare ``flag``.
            skip_value = False
            continue
        if token == flag:
            skip_value = True
            continue
        if token.startswith(inline_prefix):
            continue
        cleaned.append(token)
    return cleaned


def main():
    """Parse the quickstart sub-command and dispatch to the chosen experiment.

    One sub-parser is created per entry of ``QUICKSTART_FN``. Each accepts
    ``--show-args`` and the prologue flag (``PROLOGUE_FLAG_NAME``) naming a
    YAML config file. Arguments argparse does not recognise are left in
    ``sys.argv`` for the experiment entry point to consume.

    When the prologue flag is supplied, the config is materialised through
    ``prepare_hydra_config``; the flag and its value are then replaced in
    ``sys.argv`` by ``--config-path <dir>``. Otherwise the experiment and
    trial names come from ``get_experiment_name`` / ``get_trial_name``.
    """
    parser = argparse.ArgumentParser(prog="ReaL Quickstart")
    subparsers = parser.add_subparsers(dest="cmd", help="sub-command help")
    subparsers.required = True
    for k, v in QUICKSTART_FN.items():
        subparser = subparsers.add_parser(k)
        subparser.add_argument(
            "--show-args",
            action="store_true",
            help="Show all legal CLI arguments for this experiment.",
        )
        subparser.add_argument(
            PROLOGUE_FLAG_NAME,
            type=str,
            help="Set config (*.yaml) for this experiment.",
        )
        subparser.set_defaults(func=v)

    # Only the options declared above are parsed; everything else stays in
    # sys.argv untouched.
    args = vars(parser.parse_known_args()[0])
    if args["show_args"]:
        # Re-invoke the experiment entry point with nothing but "--help" so
        # that it prints its own argument listing.
        sys.argv = [sys.argv[0], "--help"]
        QUICKSTART_FN[args["cmd"]]()
        return

    prologue_path = args[PROLOGUE_FLAG_VAR_NAME]
    if prologue_path:
        config_dir, experiment_name, trial_name = prepare_hydra_config(
            args["cmd"], prologue_path
        )
        # list.remove() would raise ValueError for the "--flag=value" form,
        # so strip the flag positionally. Slice-assign to keep mutating the
        # same list object.
        sys.argv[:] = _strip_prologue_flag(sys.argv, PROLOGUE_FLAG_NAME)
        sys.argv += ["--config-path", f"{config_dir}"]
    else:
        experiment_name = get_experiment_name()
        trial_name = get_trial_name()

    launch_hydra_task(
        args["cmd"], experiment_name, trial_name, QUICKSTART_FN[args["cmd"]]
    )
|
|
|
|
|
|
def prepare_hydra_config(name: str, prologue_path: str):
    """Write the prologue config to a per-trial directory and return its location.

    The YAML file at ``prologue_path`` is loaded, its ``experiment_name`` and
    ``trial_name`` entries (possibly absent) are passed through
    ``get_experiment_name`` / ``get_trial_name``, and the config -- minus the
    ``PROLOGUE_EXTERNAL_CONFIG_NAME`` entry -- is dumped to
    ``<fileroot>/configs/<user>/<experiment_name>/<trial_name>/<name>.yaml``.

    Args:
        name: Base name (without extension) of the YAML file to write.
        prologue_path: Path of the user-supplied YAML config.

    Returns:
        A ``(config_dir, experiment_name, trial_name)`` tuple.
    """
    config = OmegaConf.load(prologue_path)
    experiment_name = get_experiment_name(config.get("experiment_name"))
    trial_name = get_trial_name(config.get("trial_name"))
    config_dir = f"{cluster_spec.fileroot}/configs/{getpass.getuser()}/{experiment_name}/{trial_name}"
    # The per-trial directory normally does not exist yet; without this the
    # open() below fails with FileNotFoundError.
    os.makedirs(config_dir, exist_ok=True)

    # Remove the prologue-only entry before writing the config back out.
    config.pop(PROLOGUE_EXTERNAL_CONFIG_NAME)
    with open(f"{config_dir}/{name}.yaml", "w") as f:
        f.write(OmegaConf.to_yaml(config))

    return (config_dir, experiment_name, trial_name)
|
|
|
|
def launch_hydra_task(name: str, experiment_name: str, trial_name: str, func: hydra.TaskFunction):
    """Adjust ``sys.argv`` with hydra overrides and invoke ``func``.

    Steps, in order:
      1. Append ``hydra/job_logging=disabled`` unless some argument already
         contains it.
      2. Reject any of the multi-run switches.
      3. Append a ``hydra.run.dir`` override pointing into the per-trial log
         directory.
      4. Drop ``sys.argv[1]`` and call ``func`` with no arguments.

    Raises:
        NotImplementedError: If a multi-run switch is present in ``sys.argv``.
    """
    logging_override = "hydra/job_logging=disabled"
    already_disabled = False
    for token in sys.argv:
        if logging_override in token:
            already_disabled = True
            break
    if not already_disabled:
        sys.argv.append(logging_override)

    multirun_switches = ("--multirun", "hydra.mode=MULTIRUN", "-m")
    for switch in multirun_switches:
        if switch in sys.argv:
            raise NotImplementedError("Hydra multi-run is not supported.")

    # Single-run mode: direct hydra's output directory to the trial's logs.
    log_root = f"{cluster_spec.fileroot}/logs/{getpass.getuser()}/"
    run_dir = f"{experiment_name}/{trial_name}/hydra-outputs/"
    sys.argv.append(f"hydra.run.dir={log_root}{run_dir}")

    # Remove the argument at index 1 before handing control to ``func``.
    del sys.argv[1]

    func()
|
|
|
|
|
|
# Run the quickstart CLI only when this file is executed as a script.
if __name__ == "__main__":
    main()
|