fix arg parse

This commit is contained in:
晓雷 2025-07-11 16:47:29 +08:00
parent 57ce1213ae
commit 97511e43ff
5 changed files with 47 additions and 34 deletions

View File

@ -629,8 +629,15 @@ class GRPOConfig(BaseExperimentConfig):
ref: PPOActorConfig = field(default_factory=PPOActorConfig)
def parse_cli_args(argv: List[str]):
parser = argparse.ArgumentParser()
@dataclass
class ArgParseResult:
config: BaseExperimentConfig
config_file: Path
additional_args: Optional[argparse.Namespace] = None
def parse_cli_args(argv: List[str], parser: Optional[argparse.ArgumentParser] = None) -> ArgParseResult:
if parser is None:
parser = argparse.ArgumentParser()
parser.add_argument(
"--config", help="The path of the main configuration file", required=True
)
@ -646,8 +653,11 @@ def parse_cli_args(argv: List[str]):
config_name=str(relpath.name).rstrip(".yaml"),
overrides=overrides,
)
return cfg, config_file
return ArgParseResult(
config=cfg,
config_file=config_file,
additional_args=args
)
def to_structured_cfg(cfg, config_cls):
# Merge with the default configuration.
@ -657,8 +667,13 @@ def to_structured_cfg(cfg, config_cls):
return cfg
def load_expr_config(argv: List[str], config_cls) -> Tuple[BaseExperimentConfig, str]:
cfg, config_file = parse_cli_args(argv)
def load_expr_config(
argv: List[str],
config_cls,
parser: Optional[argparse.ArgumentParser] = None
) -> ArgParseResult:
r = parse_cli_args(argv, parser=parser)
cfg = r.config
cfg = to_structured_cfg(cfg, config_cls=config_cls)
cfg = OmegaConf.to_object(cfg)
assert isinstance(cfg, BaseExperimentConfig)
@ -668,4 +683,8 @@ def load_expr_config(argv: List[str], config_cls) -> Tuple[BaseExperimentConfig,
constants.set_experiment_trial_names(cfg.experiment_name, cfg.trial_name)
name_resolve.reconfigure(cfg.cluster.name_resolve)
return cfg, str(config_file)
return ArgParseResult(
config=cfg,
config_file=r.config_file,
additional_args=r.additional_args
)

View File

@ -243,7 +243,8 @@ class LocalLauncher:
def main_local():
cfg, _ = parse_cli_args(sys.argv[2:])
r = parse_cli_args(sys.argv[2:])
cfg = r.config
alloc_mode = AllocationMode.from_str(cfg.allocation_mode)
launcher = LocalLauncher(cfg.experiment_name, cfg.trial_name, cfg.cluster.fileroot)

View File

@ -5,10 +5,11 @@ import re
import os
import getpass
import argparse
import sys
from omegaconf import OmegaConf
from arealite.api.cli_args import SGLangConfig, ClusterSpecConfig
from arealite.api.cli_args import SGLangConfig, ClusterSpecConfig, parse_cli_args, to_structured_cfg
from arealite.api.io_struct import AllocationMode, AllocationType
from arealite.launcher.utils import find_config, find_and_amend_config
import realhf.base.logging as logging
@ -469,7 +470,7 @@ class SlurmLauncher:
)
def parse_args():
def slurm_args_parser():
parser = argparse.ArgumentParser(description="Slurm Launcher for AReaL")
parser.add_argument("entry_point", type=str, help="The entry point script to run.")
parser.add_argument("config_path", type=str, help="Path to the configuration file.")
@ -478,25 +479,19 @@ def parse_args():
help="Base port for SGLang servers. SGLang servers on the same node will use successive ports starting from this base."
)
parser.add_argument("--trainer-port", type=int, required=False, default=27009, help="Pytorch distributed initialization port for trainer.")
parser.add_argument("remaining_args", nargs='*', help="Additional arguments to pass to the entry point script.")
# parser.add_argument("remaining_args", nargs='*', help="Additional arguments to pass to the entry point script.")
return parser.parse_args()
return parser
if __name__ == "__main__":
# usage: python -m arealite.launcher.slurm <entry_point> --allocation_mode <allocation_mode> <config_path> [<args>]
args = parse_args()
config = OmegaConf.load(args.config_path)
# Fix config with remaining args
config = OmegaConf.merge(config, OmegaConf.from_dotlist(args.remaining_args))
r = parse_cli_args(sys.argv[2:], parser=slurm_args_parser())
config = r.config
args = r.additional_args
cluster_config: ClusterSpecConfig = find_and_amend_config(config, "cluster", ClusterSpecConfig)
assert cluster_config is not None, "Cluster configuration is required for slurm launcher."
n_nodes = find_config(config, "n_nodes")
n_gpus_per_node = find_config(config, "n_gpus_per_node")
assert n_gpus_per_node is not None and isinstance(n_gpus_per_node, int)
assert n_nodes is not None and isinstance(n_nodes, int)
cluster_config: ClusterSpecConfig = config.cluster
n_nodes = config.n_nodes
n_gpus_per_node = config.n_gpus_per_node
if n_gpus_per_node < cluster_config.n_gpus_per_node:
raise ValueError(
f"Slurm Launcher requires at least {cluster_config.n_gpus_per_node} (#GPUs per node) GPUs. For use cases with fewer GPUs, use LocalLauncher instead."
@ -507,12 +502,11 @@ if __name__ == "__main__":
)
launcher = SlurmLauncher(
experiment_name=find_config(config, "experiment_name"),
trial_name=find_config(config, "trial_name"),
experiment_name=config.experiment_name,
trial_name=config.trial_name,
fileroot=cluster_config.fileroot
)
allocation_mode = find_config(config, "allocation_mode")
allocation_mode = config.allocation_mode
allocation_mode = AllocationMode.from_str(allocation_mode)
sglang_cmds = []
n_sglang_nodes = 0
@ -566,10 +560,9 @@ if __name__ == "__main__":
)
trainer_n_nodes = n_nodes - n_sglang_nodes
entry_point_cmd = f"{args.entry_point} --config {args.config_path} {' '.join(args.remaining_args)}"
trainer_cmd_template = (
f"torchrun --nnodes={{nnodes}} --nproc-per-node={{nproc_per_node}} --node-rank {{node_rank}} "
f"--master-addr $head_node_ip --master-port {args.trainer_port} {entry_point_cmd}"
f"--master-addr $head_node_ip --master-port {args.trainer_port} {' '.join(sys.argv[1:])}"
)
trainer_cmds = []

View File

@ -72,8 +72,8 @@ def gsm8k_reward_fn(
def main_grpo():
config, _ = load_expr_config(sys.argv[1:], GRPOConfig)
config: GRPOConfig
r = load_expr_config(sys.argv[1:], GRPOConfig)
config: GRPOConfig = r.config
rank = int(os.getenv("RANK"))
world_size = int(os.getenv("WORLD_SIZE"))

View File

@ -38,8 +38,8 @@ def get_gsm8k_dataset(split, tokenizer, rank, world_size):
def main_sft():
config, _ = load_expr_config(sys.argv[1:], SFTConfig)
config: SFTConfig
r = load_expr_config(sys.argv[1:], SFTConfig)
config: SFTConfig = r.config
rank = int(os.getenv("RANK"))
world_size = int(os.getenv("WORLD_SIZE"))