mirror of https://github.com/inclusionAI/AReaL
fix arg parse
This commit is contained in:
parent
57ce1213ae
commit
97511e43ff
|
@ -629,8 +629,15 @@ class GRPOConfig(BaseExperimentConfig):
|
|||
ref: PPOActorConfig = field(default_factory=PPOActorConfig)
|
||||
|
||||
|
||||
def parse_cli_args(argv: List[str]):
|
||||
parser = argparse.ArgumentParser()
|
||||
@dataclass
class ArgParseResult:
    """Bundle of values returned by ``parse_cli_args`` and ``load_expr_config``.

    Callers read the individual fields by name (e.g. ``r.config``).
    """

    # The experiment configuration produced from the config file and overrides.
    config: BaseExperimentConfig
    # Path of the main configuration file (the value supplied via ``--config``).
    config_file: Path
    # Namespace produced by the argument parser; when a caller supplies its own
    # parser, that parser's extra options are read from here. Defaults to None.
    additional_args: Optional[argparse.Namespace] = None
|
||||
|
||||
def parse_cli_args(argv: List[str], parser: Optional[argparse.ArgumentParser] = None) -> ArgParseResult:
|
||||
if parser is None:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--config", help="The path of the main configuration file", required=True
|
||||
)
|
||||
|
@ -646,8 +653,11 @@ def parse_cli_args(argv: List[str]):
|
|||
config_name=str(relpath.name).rstrip(".yaml"),
|
||||
overrides=overrides,
|
||||
)
|
||||
return cfg, config_file
|
||||
|
||||
return ArgParseResult(
|
||||
config=cfg,
|
||||
config_file=config_file,
|
||||
additional_args=args
|
||||
)
|
||||
|
||||
def to_structured_cfg(cfg, config_cls):
|
||||
# Merge with the default configuration.
|
||||
|
@ -657,8 +667,13 @@ def to_structured_cfg(cfg, config_cls):
|
|||
return cfg
|
||||
|
||||
|
||||
def load_expr_config(argv: List[str], config_cls) -> Tuple[BaseExperimentConfig, str]:
|
||||
cfg, config_file = parse_cli_args(argv)
|
||||
def load_expr_config(
|
||||
argv: List[str],
|
||||
config_cls,
|
||||
parser: Optional[argparse.ArgumentParser] = None
|
||||
) -> ArgParseResult:
|
||||
r = parse_cli_args(argv, parser=parser)
|
||||
cfg = r.config
|
||||
cfg = to_structured_cfg(cfg, config_cls=config_cls)
|
||||
cfg = OmegaConf.to_object(cfg)
|
||||
assert isinstance(cfg, BaseExperimentConfig)
|
||||
|
@ -668,4 +683,8 @@ def load_expr_config(argv: List[str], config_cls) -> Tuple[BaseExperimentConfig,
|
|||
|
||||
constants.set_experiment_trial_names(cfg.experiment_name, cfg.trial_name)
|
||||
name_resolve.reconfigure(cfg.cluster.name_resolve)
|
||||
return cfg, str(config_file)
|
||||
return ArgParseResult(
|
||||
config=cfg,
|
||||
config_file=r.config_file,
|
||||
additional_args=r.additional_args
|
||||
)
|
||||
|
|
|
@ -243,7 +243,8 @@ class LocalLauncher:
|
|||
|
||||
|
||||
def main_local():
|
||||
cfg, _ = parse_cli_args(sys.argv[2:])
|
||||
r = parse_cli_args(sys.argv[2:])
|
||||
cfg = r.config
|
||||
alloc_mode = AllocationMode.from_str(cfg.allocation_mode)
|
||||
|
||||
launcher = LocalLauncher(cfg.experiment_name, cfg.trial_name, cfg.cluster.fileroot)
|
||||
|
|
|
@ -5,10 +5,11 @@ import re
|
|||
import os
|
||||
import getpass
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from arealite.api.cli_args import SGLangConfig, ClusterSpecConfig
|
||||
from arealite.api.cli_args import SGLangConfig, ClusterSpecConfig, parse_cli_args, to_structured_cfg
|
||||
from arealite.api.io_struct import AllocationMode, AllocationType
|
||||
from arealite.launcher.utils import find_config, find_and_amend_config
|
||||
import realhf.base.logging as logging
|
||||
|
@ -469,7 +470,7 @@ class SlurmLauncher:
|
|||
)
|
||||
|
||||
|
||||
def parse_args():
|
||||
def slurm_args_parser():
|
||||
parser = argparse.ArgumentParser(description="Slurm Launcher for AReaL")
|
||||
parser.add_argument("entry_point", type=str, help="The entry point script to run.")
|
||||
parser.add_argument("config_path", type=str, help="Path to the configuration file.")
|
||||
|
@ -478,25 +479,19 @@ def parse_args():
|
|||
help="Base port for SGLang servers. SGLang servers on the same node will ."
|
||||
)
|
||||
parser.add_argument("--trainer-port", type=int, required=False, default=27009, help="Pytorch distributed initialization port for trainer.")
|
||||
parser.add_argument("remaining_args", nargs='*', help="Additional arguments to pass to the entry point script.")
|
||||
# parser.add_argument("remaining_args", nargs='*', help="Additional arguments to pass to the entry point script.")
|
||||
|
||||
return parser.parse_args()
|
||||
return parser
|
||||
|
||||
if __name__ == "__main__":
|
||||
# usage: python -m arealite.launcher.slurm <entry_point> --allocation_mode <allocation_mode> <config_path> [<args>]
|
||||
args = parse_args()
|
||||
config = OmegaConf.load(args.config_path)
|
||||
# Fix config with remaining args
|
||||
config = OmegaConf.merge(config, OmegaConf.from_dotlist(args.remaining_args))
|
||||
|
||||
r = parse_cli_args(sys.argv[2:], parser=slurm_args_parser())
|
||||
config = r.config
|
||||
args = r.additional_args
|
||||
|
||||
cluster_config: ClusterSpecConfig = find_and_amend_config(config, "cluster", ClusterSpecConfig)
|
||||
assert cluster_config is not None, "Cluster configuration is required for slurm launcher."
|
||||
|
||||
n_nodes = find_config(config, "n_nodes")
|
||||
n_gpus_per_node = find_config(config, "n_gpus_per_node")
|
||||
assert n_gpus_per_node is not None and isinstance(n_gpus_per_node, int)
|
||||
assert n_nodes is not None and isinstance(n_nodes, int)
|
||||
cluster_config: ClusterSpecConfig = config.cluster
|
||||
n_nodes = config.n_nodes
|
||||
n_gpus_per_node = config.n_gpus_per_node
|
||||
if n_gpus_per_node < cluster_config.n_gpus_per_node:
|
||||
raise ValueError(
|
||||
f"Slurm Launcher requires at least {cluster_config.n_gpus_per_node} (#GPUs per node) GPU. For usecases of less GPUs, use LocalLauncher instead."
|
||||
|
@ -507,12 +502,11 @@ if __name__ == "__main__":
|
|||
)
|
||||
|
||||
launcher = SlurmLauncher(
|
||||
experiment_name=find_config(config, "experiment_name"),
|
||||
trial_name=find_config(config, "trial_name"),
|
||||
experiment_name=config.experiment_name,
|
||||
trial_name=config.trial_name,
|
||||
fileroot=cluster_config.fileroot
|
||||
)
|
||||
|
||||
allocation_mode = find_config(config, "allocation_mode")
|
||||
allocation_mode = config.allocation_mode
|
||||
allocation_mode = AllocationMode.from_str(allocation_mode)
|
||||
sglang_cmds = []
|
||||
n_sglang_nodes = 0
|
||||
|
@ -566,10 +560,9 @@ if __name__ == "__main__":
|
|||
)
|
||||
|
||||
trainer_n_nodes = n_nodes - n_sglang_nodes
|
||||
entry_point_cmd = f"{args.entry_point} --config {args.config_path} {' '.join(args.remaining_args)}"
|
||||
trainer_cmd_template = (
|
||||
f"torchrun --nnodes={{nnodes}} --nproc-per-node={{nproc_per_node}} --node-rank {{node_rank}} "
|
||||
f"--master-addr $head_node_ip --master-port {args.trainer_port} {entry_point_cmd}"
|
||||
f"--master-addr $head_node_ip --master-port {args.trainer_port} {' '.join(sys.argv[1:])}"
|
||||
)
|
||||
|
||||
trainer_cmds = []
|
||||
|
|
|
@ -72,8 +72,8 @@ def gsm8k_reward_fn(
|
|||
|
||||
|
||||
def main_grpo():
|
||||
config, _ = load_expr_config(sys.argv[1:], GRPOConfig)
|
||||
config: GRPOConfig
|
||||
r = load_expr_config(sys.argv[1:], GRPOConfig)
|
||||
config: GRPOConfig = r.config
|
||||
|
||||
rank = int(os.getenv("RANK"))
|
||||
world_size = int(os.getenv("WORLD_SIZE"))
|
||||
|
|
|
@ -38,8 +38,8 @@ def get_gsm8k_dataset(split, tokenizer, rank, world_size):
|
|||
|
||||
|
||||
def main_sft():
|
||||
config, _ = load_expr_config(sys.argv[1:], SFTConfig)
|
||||
config: SFTConfig
|
||||
r = load_expr_config(sys.argv[1:], SFTConfig)
|
||||
config: SFTConfig = r.config
|
||||
|
||||
rank = int(os.getenv("RANK"))
|
||||
world_size = int(os.getenv("WORLD_SIZE"))
|
||||
|
|
Loading…
Reference in New Issue