mirror of https://github.com/inclusionAI/AReaL
75 lines
2.1 KiB
Python
import dataclasses
|
|
import datetime
|
|
import os
|
|
from typing import Dict
|
|
|
|
import hydra
|
|
import yaml
|
|
from omegaconf import MISSING, OmegaConf
|
|
|
|
from realhf.api.quickstart.entrypoint import kind_reminder
|
|
from realhf.base.constants import init_constants
|
|
from realhf.experiments.common.ppo_math_exp import PPOMATHConfig
|
|
from training.utils import run_experiment
|
|
|
|
|
|
@hydra.main(version_base=None, config_path="configs", config_name="sync-ppo")
def main(args):
    """Resolve the PPO-math experiment config, save it, and launch the run.

    The config Hydra composes (``configs/sync-ppo`` plus any command-line
    overrides) is merged on top of the ``PPOMATHConfig`` dataclass defaults.
    If no trial name was supplied, a timestamped one (``run<Y-m-d-H-M-S>``)
    is generated. The resolved configuration is written to
    ``<LOG_ROOT>/<experiment_name>/<trial_name>/config.yaml`` and the
    experiment is started with ``run_experiment``.

    Args:
        args: The config object Hydra passes to the task function.

    Raises:
        RuntimeError: If the resolved ``mode`` is not ``"ray"``.
    """
    # NOTE: we import logging here to avoid hydra logging overwrite
    import realhf.base.logging as logging

    logger = logging.getLogger("quickstart", "colored")

    # Overwrite the python dataclass configuration with yaml
    default_args = OmegaConf.structured(PPOMATHConfig)
    merged = OmegaConf.merge(default_args, args)

    # Set experiment trial name while the config is still a DictConfig.
    # OmegaConf.to_object() raises MissingMandatoryValue for any field that
    # is still MISSING, so a `== MISSING` check after the conversion could
    # never take effect.
    if OmegaConf.is_missing(merged, "trial_name"):
        merged.trial_name = (
            f"run{datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
        )

    config: PPOMATHConfig = OmegaConf.to_object(merged)
    exp_name = config.experiment_name
    trial_name = config.trial_name

    if config.mode != "ray":
        raise RuntimeError("This script only supports the `ray` mode.")

    init_constants(config)

    # Imported only after init_constants(); presumably that call is what
    # populates LOG_ROOT — TODO confirm against realhf.base.constants.
    from realhf.base.constants import LOG_ROOT

    # Save overwritten configuration to yaml. The dict is built before the
    # file is opened so a failure in dataclasses.asdict() does not leave an
    # empty config.yaml behind.
    config_save_path = os.path.join(LOG_ROOT, exp_name, trial_name, "config.yaml")
    os.makedirs(os.path.dirname(config_save_path), exist_ok=True)
    config_dict: Dict = dataclasses.asdict(config)
    with open(config_save_path, "w", encoding="utf-8") as f:
        yaml.dump(
            config_dict,
            f,
            default_flow_style=False,
            sort_keys=False,
        )

    kind_reminder("ppo-math", logger, config)

    run_experiment(config, exp_name, trial_name)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
import argparse
|
|
|
|
parser = argparse.ArgumentParser(add_help=False)
|
|
parser.add_argument("--help", action="store_true")
|
|
args = parser.parse_args()
|
|
if args.help:
|
|
from realhf.api.cli_args import print_config_help
|
|
|
|
print_config_help(PPOMATHConfig())
|
|
exit(0)
|
|
main()
|