mirror of https://github.com/inclusionAI/AReaL
rename preset to prologue
This commit is contained in:
parent
62b9bf3f44
commit
c2b03f62d6
|
@ -13,7 +13,7 @@ import os
|
|||
import hydra
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
from realhf.base.preset import PRESET_FLAG_NAME, PRESET_FLAG_VAR_NAME, PRESET_EXTERNAL_CONFIG_NAME, get_experiment_name, get_trial_name
|
||||
from realhf.base.prologue import PROLOGUE_FLAG_NAME, PROLOGUE_FLAG_VAR_NAME, PROLOGUE_EXTERNAL_CONFIG_NAME, get_experiment_name, get_trial_name
|
||||
from realhf.api.quickstart.entrypoint import QUICKSTART_FN
|
||||
from realhf.base.cluster import spec as cluster_spec
|
||||
from realhf.base.importing import import_module
|
||||
|
@ -38,9 +38,9 @@ def main():
|
|||
help="Show all legal CLI arguments for this experiment.",
|
||||
)
|
||||
subparser.add_argument(
|
||||
PRESET_FLAG_NAME,
|
||||
PROLOGUE_FLAG_NAME,
|
||||
type=str,
|
||||
help="Set preset config (*.yaml) for this experiment.",
|
||||
help="Set config (*.yaml) for this experiment.",
|
||||
)
|
||||
subparser.set_defaults(func=v)
|
||||
args = vars(parser.parse_known_args()[0])
|
||||
|
@ -52,10 +52,10 @@ def main():
|
|||
experiment_name = ""
|
||||
trial_name = ""
|
||||
|
||||
if args[PRESET_FLAG_VAR_NAME]:
|
||||
config_dir, experiment_name, trial_name = prepare_hydra_config(args["cmd"], args[PRESET_FLAG_VAR_NAME])
|
||||
sys.argv.remove(PRESET_FLAG_NAME)
|
||||
sys.argv.remove(args[PRESET_FLAG_VAR_NAME])
|
||||
if args[PROLOGUE_FLAG_VAR_NAME]:
|
||||
config_dir, experiment_name, trial_name = prepare_hydra_config(args["cmd"], args[PROLOGUE_FLAG_VAR_NAME])
|
||||
sys.argv.remove(PROLOGUE_FLAG_NAME)
|
||||
sys.argv.remove(args[PROLOGUE_FLAG_VAR_NAME])
|
||||
sys.argv += [f"--config-path", f"{config_dir}"]
|
||||
else:
|
||||
experiment_name = get_experiment_name()
|
||||
|
@ -64,13 +64,13 @@ def main():
|
|||
launch_hydra_task(args["cmd"], experiment_name, trial_name, QUICKSTART_FN[args["cmd"]])
|
||||
|
||||
|
||||
def prepare_hydra_config(name: str, preset_path: str):
|
||||
config = OmegaConf.load(preset_path)
|
||||
def prepare_hydra_config(name: str, PROLOGUE_path: str):
|
||||
config = OmegaConf.load(PROLOGUE_path)
|
||||
experiment_name = get_experiment_name(config.get("experiment_name"))
|
||||
trial_name = get_trial_name(config.get("trial_name"))
|
||||
config_dir = f"{cluster_spec.fileroot}/configs/{getpass.getuser()}/{experiment_name}/{trial_name}"
|
||||
|
||||
config.pop(PRESET_EXTERNAL_CONFIG_NAME)
|
||||
config.pop(PROLOGUE_EXTERNAL_CONFIG_NAME)
|
||||
with open(f"{config_dir}/{name}.yaml", "w") as f:
|
||||
f.write(OmegaConf.to_yaml(config))
|
||||
|
||||
|
|
|
@ -9,20 +9,20 @@ import os
|
|||
import json
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
PRESET_FLAG_NAME = "--config"
|
||||
PRESET_FLAG_VAR_NAME = "config"
|
||||
PRESET_EXTERNAL_CONFIG_NAME = "external_configs"
|
||||
PROLOGUE_FLAG_NAME = "--config"
|
||||
PROLOGUE_FLAG_VAR_NAME = "config"
|
||||
PROLOGUE_EXTERNAL_CONFIG_NAME = "external_configs"
|
||||
|
||||
def global_init():
|
||||
parser = argparse.ArgumentParser(add_help=False)
|
||||
parser.add_argument(PRESET_FLAG_NAME)
|
||||
parser.add_argument(PROLOGUE_FLAG_NAME)
|
||||
args = vars(parser.parse_known_args()[0])
|
||||
if args[PRESET_FLAG_VAR_NAME] is None:
|
||||
if args[PROLOGUE_FLAG_VAR_NAME] is None:
|
||||
return
|
||||
preset_path = args[PRESET_FLAG_VAR_NAME]
|
||||
PROLOGUE_path = args[PROLOGUE_FLAG_VAR_NAME]
|
||||
|
||||
config = OmegaConf.load(preset_path)
|
||||
external_configs = config.get(PRESET_EXTERNAL_CONFIG_NAME)
|
||||
config = OmegaConf.load(PROLOGUE_path)
|
||||
external_configs = config.get(PROLOGUE_EXTERNAL_CONFIG_NAME)
|
||||
|
||||
if external_configs is None:
|
||||
return
|
||||
|
|
Loading…
Reference in New Issue