# Mirror of https://github.com/inclusionAI/AReaL
# 171 lines, 5.2 KiB, Python
# Copyright 2025 Ant Group Inc.
|
|
# Copyright 2024 Wei Fu & Zhiyu Mei
|
|
# Licensed under the Apache License, Version 2.0 (the "License").
|
|
|
|
import argparse
|
|
import datetime
|
|
import getpass
|
|
import os
|
|
import pathlib
|
|
import re
|
|
import sys
|
|
|
|
import hydra
|
|
from omegaconf import DictConfig, OmegaConf
|
|
from rich.panel import Panel
|
|
|
|
from realhf.api.cli_args import console, highlighter, print_config_help
|
|
from realhf.api.quickstart.entrypoint import QUICKSTART_CONFIG_CLASSES, QUICKSTART_FN
|
|
from realhf.base.cluster import spec as cluster_spec
|
|
from realhf.base.importing import import_module
|
|
from realhf.base.prologue import (
|
|
PROLOGUE_EXTERNAL_CONFIG_NAME,
|
|
PROLOGUE_FLAG_NAME,
|
|
PROLOGUE_FLAG_VAR_NAME,
|
|
get_experiment_name,
|
|
get_trial_name,
|
|
)
|
|
from realhf.version import get_full_version_with_dirty_description
|
|
|
|
# NOTE: Register all implemented experiments inside ReaL.
# Every module matching `*_exp.py` under the `experiments/common` and
# `experiments/async_exp` directories (siblings of this file's parent
# package) is imported through `import_module`, in that order.
_EXPERIMENTS_ROOT = pathlib.Path(__file__).resolve().parent.parent / "experiments"
for _exp_subdir in ("common", "async_exp"):
    import_module(str(_EXPERIMENTS_ROOT / _exp_subdir), re.compile(r".*_exp\.py$"))

# Imported only for its import-time side effects; the module name is not
# referenced below. (Presumably registers the profiling experiment — confirm.)
import realhf.experiments.benchmark.profile_exp
|
|
|
|
|
|
def print_help(exp_type):
    """Print comprehensive, rich-formatted help for one quickstart experiment.

    The output consists of a header panel, the option listing produced by
    ``print_config_help`` for a default-constructed instance of the
    experiment's config class, a usage line, an example with command-line
    overrides, and a dim footer.

    Args:
        exp_type: Name of the quickstart experiment; used as a key into
            ``QUICKSTART_CONFIG_CLASSES``.
    """
    default_config = QUICKSTART_CONFIG_CLASSES[exp_type]()

    # Header panel.
    title_markup = f"[header]Configuration Help for {exp_type}[/header]"
    console.print(Panel.fit(title_markup, border_style="border"))

    # Option listing for the experiment's config.
    console.print("\n[title]CONFIGURATION OPTIONS[/title]")
    print_config_help(default_config)

    # Usage and example commands share the same heading + highlighted-line layout.
    command_sections = (
        (
            "\n[title]USAGE[/title]",
            f"python -m realhf.apps.quickstart {exp_type} --config ./your/config.yaml [OPTIONS]",
        ),
        (
            "\n[title]EXAMPLE OVERRIDES[/title]",
            f"python -m realhf.apps.quickstart {exp_type} --config ./your/config.yaml dataset.path=/my/dataset.jsonl actor.optimizer.lr=2e-5",
        ),
    )
    for heading, command_line in command_sections:
        console.print(heading)
        console.print(highlighter(command_line))

    # Footer.
    console.print("\n[dim]Use [bold]--help[/bold] to show this message again[/dim]")
|
|
|
|
|
|
def print_version():
    """Print the version string from ``get_full_version_with_dirty_description``."""
    version_description = get_full_version_with_dirty_description()
    console.print(f"AReaL Version: {version_description}")
|
|
|
|
|
|
def _remove_prologue_flag(argv, flag, value):
    """Delete the prologue config option from ``argv`` in place.

    Both spellings argparse accepts are handled: two adjacent tokens
    (``flag value``) or a single ``flag=value`` token. Only the first
    matching occurrence is removed, so an unrelated argument that merely
    happens to equal ``value`` is left alone.

    Args:
        argv: Argument list to edit (typically ``sys.argv``).
        flag: The option string, e.g. ``PROLOGUE_FLAG_NAME``.
        value: The value argparse parsed for ``flag``.

    Raises:
        ValueError: If neither spelling is present in ``argv``.
    """
    for idx, token in enumerate(argv):
        if token == flag and idx + 1 < len(argv) and argv[idx + 1] == value:
            del argv[idx : idx + 2]
            return  # stop immediately: argv was mutated
        if token == f"{flag}={value}":
            del argv[idx]
            return
    raise ValueError(f"Could not find {flag} {value!r} in command line {argv!r}.")


def main():
    """Command-line entry point of the ReaL quickstart launcher.

    Builds one sub-command per entry of ``QUICKSTART_FN``. ``--version`` and
    ``--help`` are handled manually. argparse's built-in help is disabled so
    that ``--help`` after a sub-command prints the rich per-experiment help.
    When a prologue config file is given, it is materialized by
    ``prepare_hydra_config``. The option is then replaced on ``sys.argv`` by
    ``--config-path <dir>`` before ``launch_hydra_task`` is called. Arguments
    argparse does not recognize (e.g. ``key=value`` overrides) are tolerated
    and left on ``sys.argv``.
    """
    # Create parser with add_help=False to disable automatic -h/--help.
    parser = argparse.ArgumentParser(prog="ReaL Quickstart", add_help=False)
    # Custom help flag that won't conflict with the sub-parsers' own --help.
    parser.add_argument(
        "--help", action="store_true", help="Show this help message and exit"
    )
    parser.add_argument("--version", action="store_true", help="Show AReaL version")

    subparsers = parser.add_subparsers(dest="cmd", help="sub-command help")
    # The sub-command is validated manually after parsing. Marking it required
    # here would make argparse exit with a usage error before `--version` or a
    # top-level `--help` could be handled.
    subparsers.required = False

    for cmd_name, task_fn in QUICKSTART_FN.items():
        subparser = subparsers.add_parser(cmd_name, add_help=False)
        subparser.add_argument(
            "--help", action="store_true", help="Show help for this command"
        )
        subparser.add_argument(
            PROLOGUE_FLAG_NAME,
            type=str,
            help="Set config (*.yaml) for this experiment.",
        )
        subparser.set_defaults(func=task_fn)

    # parse_known_args: unknown tokens stay on sys.argv for the launched task.
    args = vars(parser.parse_known_args()[0])

    if args["version"]:
        print_version()
        return

    # Handle help at both main and sub-command levels.
    if args["help"]:
        if args["cmd"]:
            print_help(args["cmd"])
        else:
            parser.print_help()
        return

    if not args["cmd"]:
        # No sub-command: show help and exit with argparse's usage-error status
        # rather than falling through to code that needs a sub-command.
        parser.print_help()
        parser.exit(2)

    cmd = args["cmd"]
    prologue_path = args[PROLOGUE_FLAG_VAR_NAME]
    if prologue_path:
        config_dir, experiment_name, trial_name = prepare_hydra_config(
            cmd, prologue_path
        )
        _remove_prologue_flag(sys.argv, PROLOGUE_FLAG_NAME, prologue_path)
        sys.argv += ["--config-path", config_dir]
    else:
        experiment_name = get_experiment_name()
        trial_name = get_trial_name()

    launch_hydra_task(cmd, experiment_name, trial_name, QUICKSTART_FN[cmd])
|
|
|
|
|
|
def prepare_hydra_config(name: str, prologue_path: str):
    """Copy a prologue YAML config into a per-trial Hydra config directory.

    Loads ``prologue_path`` and resolves the experiment and trial names by
    passing its ``experiment_name`` / ``trial_name`` values (``None`` when
    absent) to ``get_experiment_name`` / ``get_trial_name``. It then drops the
    ``PROLOGUE_EXTERNAL_CONFIG_NAME`` key and writes the remainder to
    ``<fileroot>/configs/<user>/<experiment_name>/<trial_name>/<name>.yaml``.

    Args:
        name: Quickstart sub-command; used as the base name of the YAML file.
        prologue_path: Path to the user-supplied YAML config.

    Returns:
        Tuple ``(config_dir, experiment_name, trial_name)``, where
        ``config_dir`` is a ``str`` path to the created directory.
    """
    config = OmegaConf.load(prologue_path)
    experiment_name = get_experiment_name(config.get("experiment_name"))
    trial_name = get_trial_name(config.get("trial_name"))
    config_dir = f"{cluster_spec.fileroot}/configs/{getpass.getuser()}/{experiment_name}/{trial_name}"
    os.makedirs(config_dir, exist_ok=True)

    # Exclude this section from the written file; a missing key is tolerated.
    config.pop(PROLOGUE_EXTERNAL_CONFIG_NAME, {})
    # Write explicitly as UTF-8 so non-ASCII config values do not depend on
    # the process locale's default encoding.
    with open(f"{config_dir}/{name}.yaml", "w", encoding="utf-8") as f:
        f.write(OmegaConf.to_yaml(config))

    return (config_dir, experiment_name, trial_name)
|
|
|
|
|
|
def launch_hydra_task(
    name: str, experiment_name: str, trial_name: str, func: hydra.TaskFunction
):
    """Adjust ``sys.argv`` and then call the task function ``func``.

    Unless some argument already contains ``hydra/job_logging=disabled``, that
    override is inserted at index 2 (right after the sub-command). The
    sub-command token at index 1 is then removed. ``name``,
    ``experiment_name`` and ``trial_name`` are accepted but unused here.
    """
    logging_override = "hydra/job_logging=disabled"

    # Disable hydra logging unless the caller already configured it.
    override_present = False
    for token in sys.argv:
        if logging_override in token:
            override_present = True
            break
    if not override_present:
        sys.argv.insert(2, logging_override)

    # Drop the sub-command token before handing control to `func`.
    del sys.argv[1]

    func()
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the quickstart CLI when this module is executed as a script,
    # e.g. `python -m realhf.apps.quickstart <exp_type> --config <file>`.
    main()
|