[Feature] [lite] Merge from internal dev repo (#189)

* PullRequest: 353 [Lite] Add gradient checkpointing to FSDPEngine

Merge branch mzy/add-gradient-ckpt of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/353

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* add gradient checkpointing

* PullRequest: 354 [lite] GRPO pre-commit: minor changes in FSDP engine

Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/354

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .
* .
* .

* PullRequest: 355 [Lite] GRPO pre-commit 2: Refactor RemoteSGLangEngine thread and SGLang configuration

Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/355?tab=commit

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .
* .
* .
* .
* .
* fix
* .

* PullRequest: 357 [lite] GRPO pre-commit 3: Fix typos and experiment utilities

Merge branch fw/lite-fix2 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/357?tab=comment

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .
* .
* .
* .
* fix destroy process group

* PullRequest: 358 [lite] Support GRPO training locally with the GSM8k dataset

Merge branch fw/lite-fix3 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/358

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .
* .
* .
* fix loss mask
* fix
* .

* PullRequest: 368 [lite] Refactor train engine after merging contributions from GitHub

Merge branch fw/lite-train-engine of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/368

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .

* PullRequest: 371 [lite] [fix] fix misc bugs in GRPO implementation

Merge branch fw/lite-fix0716 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/371

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .

* PullRequest: 370 [lite] Add Slurm Launcher and Ray Launcher

Merge branch mzy/lite/launcher of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/370

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* .
* .
* .
* fix

* PullRequest: 392 [lite] Fix several bugs regarding RL learning and add an example to reproduce boba-math results.

Merge branch fw/lite-boba of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/392

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* support fsdp engine and sglang remote engine
* minor fix
* .
* refactor trainer
* add close
* rm mb_spec
* .
* fix
* .
* qwen2 grpo works
* fix
* fix
* async works
* fix
* slurm launcher not tested
* fix arg parse
* .
* sglang server wrapper
* .
* .
* slurm run
* ready for boba
* debug
* 32k run
* .
* .
* fix
* .
* .
* .
* .
* .
* fix
* .
* fix
* .
* .
* .
* .
* fix
* .
* .
* .
* .
* .
* .
* .
* refactor train engine
* refactor train engine
* .
* fix update weight error
* .
* .
* match train
* format
* .
* fix
* seems to work
* .
* .
* .
* .

* format

* format

* .

---------

Co-authored-by: 晓雷 <meizhiyu.mzy@antgroup.com>
Wei Fu 2025-07-21 12:52:43 +08:00 committed by GitHub
parent f68a4f677d
commit 18f8a056b6
20 changed files with 2023 additions and 153 deletions

View File

@ -676,20 +676,6 @@ class ClusterSpecConfig:
"help": "Root for logs and checkpoints. Should be available to all nodes."
},
)
gpu_type: str = field(
default="tesla", metadata={"help": "GPU type of the cluster. Used by slurm."}
)
mount: str = field(
default="/storage:/storage", metadata={"help": "Mount path for slurm."}
)
gpu_image: str = field(default="", metadata={"help": "slurm image for trainers."})
cpu_image: str = field(default="", metadata={"help": "slurm image for CPU jobs."})
gpu_infer_image: str = field(
default="", metadata={"help": "slurm image for LLM inference."}
)
node_name_prefix: str = field(
default="slurmd-", metadata={"help": "Node prefix for a slurm cluster."}
)
n_nodes: int = field(
default=32,
metadata={
@ -725,6 +711,72 @@ class DatasetConfig:
drop_last: bool = field(default=True)
@dataclass
class SlurmLauncherConfig:
"""Configuration for launching the SGLang server with Slurm."""
srun_additional_args: str = field(
default="--overlap --mpi=pmi2 -K --chdir $PWD",
metadata={"help": "Additional arguments to pass to the srun command."},
)
container_type: str = field(
default="apptainer",
metadata={
"help": "Type of containers used in slurm",
"choices": ["apptainer", "none"],
},
)
mount: str = field(
default="/storage:/storage", metadata={"help": "Mount path for slurm."}
)
trainer_image: str = field(
default="", metadata={"help": "slurm image for trainers."}
)
inference_server_image: str = field(
default="", metadata={"help": "slurm image for LLM inference."}
)
@dataclass
class LauncherConfig:
"""Configuration for launching the SGLang server."""
inference_server_cpus_per_gpu: int = field(
default=4,
metadata={"help": "Number of CPUs allocated per GPU for inference server. "},
)
inference_server_mem_per_gpu: int = field(
default=32 * 1024,
metadata={"help": "Memory allocated per GPU for inference server in MB. "},
)
trainer_cpus_per_gpu: int = field(
default=4,
metadata={"help": "Number of CPUs allocated per GPU for training. "},
)
trainer_mem_per_gpu: int = field(
default=32 * 1024,
metadata={"help": "Memory allocated per GPU for training in MB. "},
)
inference_server_env_vars: str = field(
default="",
metadata={
"help": "Environment variables for inference server, seperated by commas. "
"Example: 'ENV1=val1,ENV2=val2'. "
},
)
trainer_env_vars: str = field(
default="",
metadata={
"help": "Environment variables for training, seperated by commas. "
"Example: 'ENV1=val1,ENV2=val2'. "
},
)
slurm: SlurmLauncherConfig = field(
default_factory=SlurmLauncherConfig,
metadata={"help": "Slurm launcher configuration."},
)
@dataclass
class BaseExperimentConfig:
# NOTE: we need this unified config class because different experiments
@ -742,12 +794,6 @@ class BaseExperimentConfig:
default_factory=ClusterSpecConfig,
metadata={"help": "Cluster specification. Mainly used by slurm."},
)
n_nodes: int = field(
default=1, metadata={"help": "Number of nodes for experiment."}
)
n_gpus_per_node: int = field(
default=8, metadata={"help": "Number of GPUs per node for this experiment."}
)
allocation_mode: str = field(
default="",
metadata={
@ -785,6 +831,7 @@ class BaseExperimentConfig:
server_only: bool = False
sglang: SGLangConfig = field(default_factory=SGLangConfig)
launcher: LauncherConfig = field(default_factory=LauncherConfig)
@dataclass
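For reference, a minimal sketch of how the new launcher dataclasses compose; the field values below are illustrative assumptions, not defaults from this diff.

from arealite.api.cli_args import LauncherConfig, SlurmLauncherConfig

# Hypothetical launcher configuration: run without a container and pass extra
# settings to trainer processes via the comma-separated env-var string.
launcher = LauncherConfig(
    trainer_cpus_per_gpu=8,
    trainer_env_vars="NCCL_DEBUG=INFO,TOKENIZERS_PARALLELISM=false",
    slurm=SlurmLauncherConfig(container_type="none"),
)

The launchers introduced below pass dotted overrides of the same kind on the command line (e.g. sglang.random_seed=... in ray.py), so these fields can presumably be overridden the same way via parse_cli_args/to_structured_cfg.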

View File

@ -266,6 +266,7 @@ class BaseHFEngine(TrainEngine):
# Scale loss for accumulation
# Revert gradient averaging across dp ranks
# FIXME: should be DP size
loss_scale *= self.world_size
loss *= loss_scale
@ -286,8 +287,6 @@ class BaseHFEngine(TrainEngine):
update_successful = True
current_lr = self.lr_scheduler.get_last_lr()[0]
# Optimizer step
self.optimizer.step()
return dict(
update_successful=float(update_successful),
grad_norm=float(grad_norm) if grad_norm is not None else float("nan"),
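The FIXME above points out that reverting the data-parallel gradient averaging should use the data-parallel group size rather than the full world size; the two only coincide when FSDP data parallelism is the sole parallelism dimension. A minimal sketch of the intended scaling, assuming an initialized process group and a dp_group handle that is not part of this diff:

import torch.distributed as dist

def dp_loss_scale(accum_scale: float, dp_group=None) -> float:
    # accum_scale: the accumulation factor the engine has already applied.
    # Multiplying by the DP group size cancels the 1/dp_size gradient averaging
    # performed during reduce-scatter, instead of using self.world_size as above.
    return accum_scale * dist.get_world_size(group=dp_group)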

View File

@ -1,10 +1,7 @@
import dis
import gc
import os
import threading
import time
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Callable, Dict, Optional, Tuple
import torch
import torch.distributed as dist
@ -14,14 +11,7 @@ from torch.distributed.checkpoint.state_dict import (
StateDictOptions,
get_model_state_dict,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import (
AutoConfig,
AutoModelForCausalLM,
PreTrainedTokenizerFast,
get_constant_schedule_with_warmup,
get_linear_schedule_with_warmup,
)
from transformers import PreTrainedTokenizerFast
from arealite.api.cli_args import TrainEngineConfig
from arealite.api.engine_api import FinetuneSpec, SaveLoadMeta, WeightUpdateMeta
@ -232,6 +222,7 @@ class FSDPEngine(BaseHFEngine):
for i, (pad_length, padded_mb_input, mb_input) in enumerate(
zip(mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs)
):
self.model.set_requires_gradient_sync(i == len(mb_list.mbs) - 1)
outputs = self.model(**padded_mb_input)
logits = outputs.logits.squeeze(0)
@ -258,8 +249,6 @@ class FSDPEngine(BaseHFEngine):
update_successful = True
current_lr = self.lr_scheduler.get_last_lr()[0]
# Optimizer step
self.optimizer.step()
return dict(
update_successful=float(update_successful),
grad_norm=float(grad_norm) if grad_norm is not None else float("nan"),
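The set_requires_gradient_sync call above is the standard FSDP gradient-accumulation trick: skip the cross-rank gradient reduction on every micro-batch except the last one. A generic sketch of the same pattern; model, micro_batches, and loss_fn are placeholders, not names from this diff:

def accumulate_backward(model, micro_batches, loss_fn):
    # model is assumed to be an FSDP-wrapped module exposing
    # set_requires_gradient_sync(); micro_batches is an iterable of kwarg dicts.
    for i, mb in enumerate(micro_batches):
        is_last_mb = i == len(micro_batches) - 1
        model.set_requires_gradient_sync(is_last_mb)  # sync gradients only once
        loss = loss_fn(model(**mb))
        loss.backward()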

View File

@ -58,15 +58,9 @@ def compute_packed_sft_loss(logits: torch.Tensor, input_: TensorDict) -> torch.T
cu_seqlens.shape[0] - 1, device=logits.device, dtype=torch.float64
)
for i in range(cu_seqlens.shape[0] - 1):
m = loss_mask[cu_seqlens[i] - i : cu_seqlens[i + 1] - i - 1]
logp = logprobs[cu_seqlens[i] - i : cu_seqlens[i + 1] - i - 1]
assert cu_seqlens[i + 1] - i - 1 <= logprobs.shape[0], (
cu_seqlens,
logprobs.shape,
)
seqlogp[i] = torch.where(m, logp.detach(), 0.0).sum() / (
m.numel() - m.count_nonzero()
)
m = loss_mask[cu_seqlens[i] : cu_seqlens[i + 1]]
logp = logprobs[cu_seqlens[i] : cu_seqlens[i + 1]]
seqlogp[i] = torch.where(m, logp.detach(), 0.0).sum() / (m.count_nonzero())
## Logging stats
stats_tracker.denominator(
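The corrected loop above slices each packed sequence directly with cu_seqlens and averages log-probabilities over the tokens selected by the loss mask. A self-contained toy example of the new indexing and normalization (tensors made up for illustration):

import torch

# Two packed sequences of lengths 3 and 2 -> cu_seqlens = [0, 3, 5].
cu_seqlens = torch.tensor([0, 3, 5])
logprobs = torch.tensor([-0.5, -1.0, -2.0, -0.1, -0.3], dtype=torch.float64)
loss_mask = torch.tensor([0, 1, 1, 0, 1], dtype=torch.bool)

seqlogp = torch.zeros(cu_seqlens.shape[0] - 1, dtype=torch.float64)
for i in range(cu_seqlens.shape[0] - 1):
    m = loss_mask[cu_seqlens[i] : cu_seqlens[i + 1]]
    logp = logprobs[cu_seqlens[i] : cu_seqlens[i + 1]]
    # Average over the tokens that actually contribute to the loss.
    seqlogp[i] = torch.where(m, logp, 0.0).sum() / m.count_nonzero()
# seqlogp -> tensor([-1.5000, -0.3000], dtype=torch.float64)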

arealite/launcher/ray.py (new file, 410 lines added)
View File

@ -0,0 +1,410 @@
import getpass
import importlib.util
import os
import pathlib
import sys
import time
from typing import Dict, List, Optional
import ray
import ray.exceptions
from ray.runtime_env import RuntimeEnv
from ray.util.placement_group import PlacementGroup
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
import realhf.base.logging as logging
from arealite.api.cli_args import (
ClusterSpecConfig,
LauncherConfig,
SGLangConfig,
parse_cli_args,
to_structured_cfg,
)
from arealite.api.io_struct import AllocationMode, AllocationType
from arealite.utils.launcher import (
get_env_vars,
validate_config_for_distributed_launcher,
wait_sglang_server_addrs,
)
from arealite.utils.ray import get_placement_group_master_ip_and_port
from realhf.base import logging, name_resolve, names
from realhf.scheduler.client import JobException, JobState
logger = logging.getLogger("RayLauncher")
RAY_WAIT_CHECK_TIME_INTERVAL = 5 # seconds
DEFAULT_MAIN_FUNC_NAME = "main"
def run_func(file_path, function_name, *args, **kwargs):
# Convert the file path to a module name
module_name = file_path.replace("/", "_").replace(".", "_")
# Load the module from file path
spec = importlib.util.spec_from_file_location(module_name, file_path)
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
# Get the function and execute it
try:
function = getattr(module, function_name)
except AttributeError as e:
raise ValueError(
f"Function '{function_name}' not found in module '{module_name}'. "
f"Please ensure the name of the main function in your entry point "
f"is '{function_name}'."
) from e
return function(*args, **kwargs)
class RayLauncher:
def __init__(self, experiment_name: str, trial_name: str, fileroot: str):
self.experiment_name = experiment_name
self.trial_name = trial_name
self.fileroot = fileroot
# job_name to ray future
self.jobs = {}
@property
def run_name(self):
return f"{self.experiment_name}_{self.trial_name}"
def log_path_of(self, job_name: str) -> str:
log_path = f"{self.fileroot}/logs/{getpass.getuser()}/{self.experiment_name}/{self.trial_name}"
os.makedirs(log_path, exist_ok=True)
return os.path.join(log_path, f"{job_name}.log")
def submit(
self,
job_name: str,
file_path: str,
func_name: str,
args: List[str], # arguments to pass to the function
gpus: int,
cpus: int,
mem: int, # MB
env_vars: Optional[Dict] = None,
placement_group: Optional[PlacementGroup] = None,
bundle_index: int = -1,
kwargs: Optional[
Dict[str, str]
] = None, # keyword arguments to pass to the function
):
if kwargs is None:
kwargs = {}
runtime_env = RuntimeEnv(
env_vars=env_vars or dict(),
)
scheduling_strategy = (
PlacementGroupSchedulingStrategy(
placement_group=placement_group,
placement_group_bundle_index=bundle_index,
placement_group_capture_child_tasks=True,
)
if placement_group is not None
else "DEFAULT"
)
future = ray.remote(
num_cpus=cpus,
num_gpus=gpus,
memory=mem * 1024 * 1024, # Convert MB to bytes
runtime_env=runtime_env,
scheduling_strategy=scheduling_strategy,
)(run_func).remote(file_path, func_name, *args, **kwargs)
self.jobs[job_name] = future
return future
def submit_array(
self,
job_name: str,
file_path: str,
func_name: str,
count: int,
nodes: int,
list_args: List[List],
gpus_per_task: int,
cpus_per_task: int,
mem_per_task: int, # MB
list_kwargs: List[Dict] | None = None,
env_vars: Optional[Dict] = None,
amend_torch_dist_env: bool = False,
):
"""Submit an array of jobs to Ray with ray placement groups.
Note: Here we use `ray.remote` instead of `ray job submit` since `ray job submit`
does not support placement groups and cannot specify which node to run the job on,
so we would not know the IP addresses of the jobs for torch distributed initialization.
"""
if count % nodes != 0:
raise ValueError(
f"Count {count} is not divisible by nodes {nodes}. "
"Please ensure that count is a multiple of nodes."
)
assert (
len(list_args) == count
), f"Length of list_args {len(list_args)} does not match count {count}."
if list_kwargs is not None:
assert (
len(list_kwargs) == count
), f"Length of list_kwargs {len(list_kwargs)} does not match count {count}."
tasks_per_node = count // nodes
gpus_per_node = gpus_per_task * tasks_per_node
cpus_per_node = cpus_per_task * tasks_per_node
mem_per_node = mem_per_task * tasks_per_node
placement_group = ray.util.placement_group(
bundles=[
{
"CPU": cpus_per_node,
"GPU": gpus_per_node,
"memory": mem_per_node * 1024 * 1024, # Convert MB to bytes
}
]
* nodes,
strategy="STRICT_SPREAD",
)
try:
ray.get(placement_group.ready(), timeout=30)
except ray.exceptions.GetTimeoutError as e:
logger.error(
"Ray placement group timeout, please check if the resource requirement "
"for your experiment exceeds the available resources in the cluster. \n"
f"ray.nodes(): {ray.nodes()} \n"
f"Placement Group bundles: "
f"cpus_per_node={cpus_per_node}, gpus_per_node={gpus_per_node}, "
f"mem_per_node={mem_per_node}MB, nodes={nodes}"
)
raise e
if amend_torch_dist_env:
host_ip, port = get_placement_group_master_ip_and_port(placement_group)
logger.info(
f"Amend torch distributed env vars: "
f"MASTER_ADDR={host_ip}, PORT={port}"
)
futures = []
for i in range(count):
args = list_args[i]
kwargs = list_kwargs[i] if list_kwargs is not None else {}
# manage environment variables
env_vars = env_vars or {}
if "CUDA_VISIBLE_DEVICES" in env_vars:
logger.warning(
"Setting CUDA_VISIBLE_DEVICES before running ray jobs may result in unexpected behavior."
)
node_id = i // tasks_per_node
_env_vars = {
**env_vars,
}
if amend_torch_dist_env:
assert gpus_per_task == 1
# NOTE: Here we only provide environment variables for torch distributed
# initialization, and LOCAL_RANK for torch.device.
# Other environment variables automatically set by torchrun are not set, and
# they should never be accessed in trainer code.
_env_vars.update(
{
"RANK": str(i),
"WORLD_SIZE": str(count),
# Ray will automatically isolate CUDA_VISIBLE_DEVICES for each GPU
"LOCAL_RANK": "0",
"MASTER_ADDR": str(host_ip),
"MASTER_PORT": str(port),
}
)
future = self.submit(
job_name=f"{job_name}:{i}",
file_path=file_path,
func_name=func_name,
args=args,
gpus=gpus_per_task,
cpus=cpus_per_task,
mem=mem_per_task,
env_vars=_env_vars,
placement_group=placement_group,
bundle_index=node_id,
kwargs=kwargs,
)
futures.append(future)
return futures
def stop(self, job_name: str, force: bool = False):
"""Stop a job by name."""
if job_name in self.jobs:
future = self.jobs[job_name]
try:
ray.cancel(future, force=force)
except Exception as e:
logger.error(f"Failed to cancel job {job_name}: {e}")
return
self.jobs.pop(job_name, None)
logger.info(f"Job {job_name} stopped.")
else:
logger.warning(f"Job {job_name} not found in running jobs.")
def stop_all(self, force: bool = False):
"""Stop all jobs."""
for job_name in list(self.jobs.keys()):
self.stop(job_name, force=force)
logger.info("All jobs stopped.")
self.jobs.clear()
def wait(
self, check_status=(JobState.FAILED,), remove_status=(JobState.COMPLETED,)
):
"""Check every RAY_WAIT_CHECK_TIME_INTERVAL seconds for the status of all jobs.
If a ray job returns, its status changes to JobState.COMPLETED.
If a ray job fails, its status changes to JobState.FAILED.
If any job is in check_status, stop all jobs at once.
If any job is in remove_status, remove it from the job list.
Return when all jobs have been removed from the job list, or when some job is in check_status.
"""
for status in list(check_status) + list(remove_status):
assert status in [
JobState.COMPLETED,
JobState.FAILED,
], "In RayLauncher.wait, we only check completed or failed jobs."
logger.info(f"Waiting for {len(self.jobs)} jobs.")
while self.jobs:
job_status = {}
for job_name, future in list(self.jobs.items()):
try:
r = ray.get(future, timeout=0.1)
logger.info(f"Job {job_name} completed with result: {r}")
job_status[job_name] = JobState.COMPLETED
except ray.exceptions.RayTaskError as e:
logger.error(f"Job {job_name} failed with error: {e}.")
job_status[job_name] = JobState.FAILED
except ray.exceptions.GetTimeoutError:
continue
for job_name, status in job_status.items():
if status in check_status:
logger.info(f"Job {job_name} is {status}, stopping all jobs.")
self.stop_all(force=True)
return
if status in remove_status:
logger.info(f"Job {job_name} is {status}, removed.")
self.jobs.pop(job_name)
time.sleep(RAY_WAIT_CHECK_TIME_INTERVAL)
def ray_main():
# usage: python -m arealite.launcher.ray <entry_point> --config <config_path> [<additional_args>]
ray.init()
config, config_file = parse_cli_args(sys.argv[2:])
config.launcher = to_structured_cfg(config.launcher, LauncherConfig)
config.cluster = to_structured_cfg(config.cluster, ClusterSpecConfig)
config.sglang = to_structured_cfg(config.sglang, SGLangConfig)
validate_config_for_distributed_launcher(config)
name_resolve.reconfigure(config.cluster.name_resolve)
name_resolve.clear_subtree(
names.trial_root(
experiment_name=config.experiment_name, trial_name=config.trial_name
)
)
n_nodes = config.cluster.n_nodes
n_gpus_per_node = config.cluster.n_gpus_per_node
launcher = RayLauncher(
experiment_name=config.experiment_name,
trial_name=config.trial_name,
fileroot=config.cluster.fileroot,
)
allocation_mode = config.allocation_mode
allocation_mode = AllocationMode.from_str(allocation_mode)
sglang_cmds = []
sglang_addrs = []
n_sglang_nodes = 0
if allocation_mode.type_ == AllocationType.DECOUPLED_SGLANG:
# Launcher should launch SGLang servers according to allocation mode.
sglang_tp_size = allocation_mode.gen_tp_size
n_sglang_servers = allocation_mode.gen_dp_size
n_sglang_nodes = allocation_mode.gen_world_size // n_gpus_per_node
base_seed = config.sglang.random_seed
sglang_args_list = [
[sys.argv[2:] + [f"sglang.random_seed={base_seed + i}"]]
for i in range(n_sglang_servers)
]
sglang_entry_point = str(
pathlib.Path(__file__).resolve().parent.joinpath("sglang_server.py")
)
launcher.submit_array(
job_name="llm_server",
file_path=sglang_entry_point,
func_name=DEFAULT_MAIN_FUNC_NAME,
count=n_sglang_servers,
nodes=n_sglang_nodes,
list_args=sglang_args_list,
gpus_per_task=sglang_tp_size,
cpus_per_task=config.launcher.inference_server_cpus_per_gpu
* sglang_tp_size,
mem_per_task=config.launcher.inference_server_mem_per_gpu * sglang_tp_size,
env_vars=get_env_vars(
config.cluster.cluster_name,
config.launcher.inference_server_env_vars,
),
)
# Get SGLang server addresses via name_resolve
try:
sglang_addrs = wait_sglang_server_addrs(
config.experiment_name,
config.trial_name,
n_sglang_servers,
)
except TimeoutError as e:
launcher.stop_all(force=True)
raise e
trainer_n_nodes = n_nodes - n_sglang_nodes
trainer_entry_point = sys.argv[1]
n_trainer_processes = trainer_n_nodes * config.cluster.n_gpus_per_node
trainer_args_list = [[sys.argv[2:]] for _ in range(n_trainer_processes)]
if not config.server_only:
# In ray, we launch trainer in the granularity of processes (1 GPU per process)
# We amend environment variable similar to torchrun to ensure correct initialization of
# torch distributed.
launcher.submit_array(
job_name="trainer",
file_path=trainer_entry_point,
func_name=DEFAULT_MAIN_FUNC_NAME,
count=trainer_n_nodes * config.cluster.n_gpus_per_node,
nodes=trainer_n_nodes,
list_args=trainer_args_list,
gpus_per_task=1,
cpus_per_task=config.launcher.trainer_cpus_per_gpu,
mem_per_task=config.launcher.trainer_mem_per_gpu,
env_vars=dict(
**get_env_vars(
config.cluster.cluster_name,
config.launcher.trainer_env_vars,
),
AREAL_LLM_SERVER_ADDRS=",".join(sglang_addrs),
),
amend_torch_dist_env=True,
)
try:
launcher.wait(check_status=(JobState.COMPLETED, JobState.FAILED))
except (KeyboardInterrupt, JobException, TimeoutError) as e:
launcher.stop_all(force=True)
raise e
if __name__ == "__main__":
# usage: python -m arealite.launcher.ray \
# <entry_point> --config <config_path> [<additional_args>] \
# launcher.ray.main_func_name=<main_func_name_in_entry_point>
ray_main()
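run_func above imports the training script as a module and calls its main function in-process, passing the remaining CLI arguments as a single list (see trainer_args_list). A minimal, hypothetical entry point compatible with that convention; the file name and body are illustrative:

# my_trainer.py (hypothetical). The Ray launcher imports this file inside each
# Ray task and calls main(argv); RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT and
# LOCAL_RANK are already injected by submit_array(amend_torch_dist_env=True).
import os

def main(argv):
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    print(f"trainer rank {rank}/{world_size} received argv: {argv}")

It would then be launched with something like `python -m arealite.launcher.ray my_trainer.py --config my_config.yaml` (paths illustrative), matching the usage comment above.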

View File

@ -0,0 +1,197 @@
import os
import subprocess
import sys
import time
import uuid
from pathlib import Path
from typing import Optional
import ray
import requests
from arealite.api.cli_args import (
ClusterSpecConfig,
NameResolveConfig,
SGLangConfig,
parse_cli_args,
to_structured_cfg,
)
from arealite.api.io_struct import AllocationMode, AllocationType
from arealite.utils.launcher import TRITON_CACHE_PATH
from arealite.utils.network import find_free_ports, gethostip
from realhf.base import logging, name_resolve, names, pkg_version
logger = logging.getLogger("SGLangServer Wrapper")
def execute_shell_command(command: str) -> subprocess.Popen:
"""
Execute a shell command and return its process handle.
"""
# Replace newline continuations and split the command string.
command = command.replace("\\\n", " ").replace("\\", " ")
parts = command.split()
_env = os.environ.copy()
# To avoid DirectoryNotEmpty error caused by triton
triton_cache_path = _env.get("TRITON_CACHE_PATH", TRITON_CACHE_PATH)
unique_triton_cache_path = os.path.join(triton_cache_path, str(uuid.uuid4()))
_env["TRITON_CACHE_PATH"] = unique_triton_cache_path
return subprocess.Popen(
parts,
text=True,
env=_env,
stdout=sys.stdout,
stderr=subprocess.STDOUT,
)
def apply_sglang_patch():
p = Path(os.path.dirname(__file__))
patch_path = str(
p.parent.parent
/ "patch"
/ "sglang"
/ f"v{pkg_version.get_version('sglang')}.patch"
)
target_path = ""
sglang_meta = subprocess.check_output(
"python3 -m pip show sglang", shell=True
).decode("ascii")
for line in sglang_meta.split("\n"):
line = line.strip()
if line.startswith("Editable project location: "):
target_path = str(Path(line.split(": ")[1]).parent)
if target_path:
proc = subprocess.Popen(
["git", "apply", patch_path],
cwd=target_path,
stderr=sys.stdout,
stdout=sys.stdout,
)
proc.wait()
logger.info(f"Applied SGLang patch at {target_path}")
def launch_server_cmd(command: str):
"""
Launch the server using the given command.
If no port is specified, a free port is reserved.
"""
if not ray.is_initialized():
apply_sglang_patch()
process = execute_shell_command(command)
return process
def wait_for_server(base_url: str, timeout: Optional[int] = None) -> None:
"""Wait for the server to be ready by polling the /v1/models endpoint.
Args:
base_url: The base URL of the server
timeout: Maximum time to wait in seconds. None means wait forever.
"""
start_time = time.time()
while True:
try:
response = requests.get(
f"{base_url}/v1/models",
headers={"Authorization": "Bearer None"},
)
if response.status_code == 200:
time.sleep(5)
break
if timeout and time.time() - start_time > timeout:
raise TimeoutError("Server did not become ready within timeout period")
except requests.exceptions.RequestException:
time.sleep(1)
class SGLangServerWrapper:
def __init__(
self,
experiment_name: str,
trial_name: str,
sglang_config: SGLangConfig,
tp_size: int,
n_gpus_per_node: int,
):
self.experiment_name = experiment_name
self.trial_name = trial_name
self.config = sglang_config
self.tp_size = tp_size
self.server_process = None
self.n_gpus_per_node = n_gpus_per_node
def run(self):
gpus_per_server = len(os.getenv("CUDA_VISIBLE_DEVICES").split(","))
server_local_idx = (
int(os.getenv("CUDA_VISIBLE_DEVICES").split(",")[0]) // gpus_per_server
)
n_servers_per_node = max(1, self.n_gpus_per_node // gpus_per_server)
ports_per_server = 40000 // n_servers_per_node
port_range = (
server_local_idx * ports_per_server + 10000,
(server_local_idx + 1) * ports_per_server + 10000,
)
server_port, dist_init_port = find_free_ports(2, port_range)
dist_init_addr = f"localhost:{dist_init_port}"
host_ip = gethostip()
cmd = SGLangConfig.build_cmd(
self.config,
tp_size=self.tp_size,
base_gpu_id=0,
host=host_ip,
port=server_port,
dist_init_addr=dist_init_addr,
)
self.server_process = launch_server_cmd(cmd)
wait_for_server(f"http://{host_ip}:{server_port}")
name = names.gen_servers(self.experiment_name, self.trial_name)
name_resolve.add_subentry(name, f"{host_ip}:{server_port}")
logger.info(f"SGLang server launched at: http://{host_ip}:{server_port}")
return_code = self.server_process.wait()
logger.info(
f"SGLang server at http://{host_ip}:{server_port} exits, returncode={return_code}"
)
def __del__(self):
if self.server_process and self.server_process.poll() is None:
logger.info("Terminating SGLang server process...")
self.server_process.terminate()
self.server_process.wait()
logger.info("SGLang server process terminated.")
def main(argv):
config, _ = parse_cli_args(argv)
config.sglang = to_structured_cfg(config.sglang, SGLangConfig)
config.cluster = to_structured_cfg(config.cluster, ClusterSpecConfig)
config.cluster.name_resolve = to_structured_cfg(
config.cluster.name_resolve, NameResolveConfig
)
name_resolve.reconfigure(config.cluster.name_resolve)
allocation_mode = config.allocation_mode
allocation_mode = AllocationMode.from_str(allocation_mode)
assert allocation_mode.type_ == AllocationType.DECOUPLED_SGLANG
tp_size = allocation_mode.gen_tp_size
sglang_server = SGLangServerWrapper(
config.experiment_name,
config.trial_name,
config.sglang,
tp_size,
n_gpus_per_node=config.cluster.n_gpus_per_node,
)
sglang_server.run()
if __name__ == "__main__":
main(sys.argv[1:])
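The port bookkeeping in SGLangServerWrapper.run splits a fixed port space among the servers co-located on one node so that concurrent servers do not race for the same ports. A worked trace with assumed numbers (8 GPUs per node, TP size 4, the second server on the node):

# Illustrative only; mirrors the arithmetic in run() for CUDA_VISIBLE_DEVICES="4,5,6,7".
visible = "4,5,6,7".split(",")
gpus_per_server = len(visible)                          # 4
server_local_idx = int(visible[0]) // gpus_per_server   # 1
n_servers_per_node = max(1, 8 // gpus_per_server)       # 2
ports_per_server = 40000 // n_servers_per_node          # 20000
port_range = (
    server_local_idx * ports_per_server + 10000,        # 30000
    (server_local_idx + 1) * ports_per_server + 10000,  # 50000
)
# find_free_ports(2, port_range) then picks the HTTP and dist-init ports from this range.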

arealite/launcher/slurm.py (new file, 528 lines added)
View File

@ -0,0 +1,528 @@
import getpass
import os
import re
import subprocess
import sys
import time
from typing import Dict, List, Optional, Tuple
import realhf.base.logging as logging
from arealite.api.cli_args import (
ClusterSpecConfig,
LauncherConfig,
SGLangConfig,
parse_cli_args,
to_structured_cfg,
)
from arealite.api.io_struct import AllocationMode, AllocationType
from arealite.utils.launcher import (
get_env_vars,
validate_config_for_distributed_launcher,
wait_sglang_server_addrs,
)
from arealite.utils.slurm import (
APPTAINER_CMD_TEMPLATE,
SBATCH_SCRIPT_TEMPLATE,
SRUN_CMD_TEMPLATE,
cancel_jobs,
query_jobs,
)
from realhf.base import logging, name_resolve, names
from realhf.scheduler.client import JobException, JobInfo, JobState
logger = logging.getLogger("SlurmLauncher")
SLURM_WAIT_CHECK_TIME_INTERVAL = 5
class SlurmLauncher:
def __init__(
self, experiment_name: str, trial_name: str, fileroot: str, container_type: str
):
self.experiment_name = experiment_name
self.trial_name = trial_name
self.fileroot = fileroot
self.container_type = container_type
# slurm_job_id -> JobInfo
self.jobs: Dict[int, JobInfo] = {}
self.job_names = []
@property
def run_name(self) -> str:
"""Returns the run name of this launcher."""
return f"{self.experiment_name}_{self.trial_name}"
def slurm_name(self, job_name: str) -> str:
"""Returns the slurm name of a job."""
return f"{self.experiment_name}_{self.trial_name}:{job_name}"
def log_path_of(self, job_name: str) -> str:
log_path = f"{self.fileroot}/logs/{getpass.getuser()}/{self.experiment_name}/{self.trial_name}"
os.makedirs(log_path, exist_ok=True)
return os.path.join(log_path, f"{job_name}.log")
def sbatch_path_of(self, job_name: str) -> str:
sbatch_path = f"{self.fileroot}/logs/{getpass.getuser()}/{self.experiment_name}/{self.trial_name}"
os.makedirs(sbatch_path, exist_ok=True)
return os.path.join(sbatch_path, f"{job_name}.sh")
def submit(self, job_name, cmd, **kwargs):
"""Submits and launch a job with SBATCH.
Args:
cmd (str or List[str]): The core command to be executed.
"""
return self.submit_array(job_name, cmd, count=1, **kwargs)
def find_job_id(self, job_name: str):
job_name = self.slurm_name(job_name)
for job_id, job_info in self.jobs.items():
if job_info.name == job_name:
return job_id
return None
def submit_array(
self,
job_name: str,
cmd: List[str] | str,
count: int,
nodes: int,
n_gpus_per_node: int,
cpus_per_task: int,
mem_per_task: int, # MB
container_image: str,
srun_additional_args: str = "",
container_mounts: Optional[str] = None,
env_vars: Optional[Dict] = None,
nodelist: Optional[str] = None,
exclude: Optional[str] = None,
):
"""Submits and launch a job array with SBATCH.
Note that a job array has one (unique) slurm name, and one (unique) slurm id.
Args:
job_name (str): The job name of the job array. The actual slurm name will be
`<experiment_name>_<trial_name>:<job_name>`.
cmd (str or List[str]): The core command to be executed.
count (int): The number of jobs in the array.
"""
assert job_name not in self.job_names, (
f"Job {job_name} is already submitted. "
"Please use a different job name or stop the existing job."
)
if isinstance(cmd, str):
cmd = [cmd]
assert len(cmd) == count, (
f"Command length {len(cmd)} does not match the job count {count}. "
"Please provide a command for each job in the array."
)
assert count % nodes == 0, (
f"Job count {count} must be divisible by the number of nodes {nodes}. "
"Please adjust the job count or the number of nodes."
)
ntasks_per_node = count // nodes
assert n_gpus_per_node % ntasks_per_node == 0, (
"GPUs must be evenly distributed across tasks. "
f"Current #GPUs per node {n_gpus_per_node}, #tasks per node {ntasks_per_node}."
)
mem_per_cpu = mem_per_task // cpus_per_task # MB per CPU
mem_per_node = (
mem_per_task * count // nodes + 1024 * 10
) # make sure slurm does not run out of resources
sbatch_options = [
f"--job-name={self.slurm_name(job_name)}",
f"--output={self.log_path_of(job_name)}",
"--open-mode=append",
"--no-requeue",
f"--nodes={nodes}-{nodes}",
f"--ntasks-per-node={ntasks_per_node}",
f"--gres=gpu:{n_gpus_per_node}",
f"--cpus-per-task={cpus_per_task}",
f"--mem={mem_per_node}M",
]
if nodelist:
sbatch_options.append(f"--nodelist={nodelist}")
if exclude:
sbatch_options.append(f"--exclude={exclude}")
sbatch_options_str = "\n".join([f"#SBATCH {opt}" for opt in sbatch_options])
if env_vars is None:
env_vars = dict()
n_gpus_per_task = n_gpus_per_node // ntasks_per_node
assert (
"CUDA_VISIBLE_DEVICES" not in env_vars
), "CUDA_VISIBLE_DEVICES should be automatically resolved by Launcher instead of manually assigned."
srun_cmds = []
for i in range(count):
# resolve CUDA_VISIBLE_DEVICES for each task
gpu_id_start = (i % ntasks_per_node) * n_gpus_per_task
gpu_id_end = ((i % ntasks_per_node) + 1) * n_gpus_per_task
node_id = i // ntasks_per_node
_env_vars = {
**env_vars,
"CUDA_VISIBLE_DEVICES": ",".join(
str(x) for x in range(gpu_id_start, gpu_id_end)
),
}
# Prepare the command for each job in the array
job_cmd = cmd[i]
if self.container_type == "apptainer":
env_string = " ".join(
"--env {}={}".format(k, v) for k, v in _env_vars.items()
)
apptainer_cmd = APPTAINER_CMD_TEMPLATE.format(
container_mounts=container_mounts or "",
container_env_strings=env_string,
container_image=container_image,
cmd=job_cmd,
)
srun_cmd = SRUN_CMD_TEMPLATE.format(
additional_args=srun_additional_args,
nodes=1,
ntasks=1,
node_id=node_id,
n_gpus_per_node=n_gpus_per_task,
cpus_per_task=cpus_per_task,
mem_per_cpu=mem_per_cpu,
cmd=apptainer_cmd,
)
elif self.container_type == "none":
env_string = "--export=" + ",".join(
"{}={}".format(k, v) for k, v in _env_vars.items()
)
srun_additional_args = srun_additional_args + " " + env_string
srun_cmd = SRUN_CMD_TEMPLATE.format(
additional_args=srun_additional_args,
nodes=1,
ntasks=1,
node_id=node_id,
n_gpus_per_node=n_gpus_per_task,
cpus_per_task=cpus_per_task,
mem_per_cpu=mem_per_cpu,
cmd=job_cmd,
)
else:
raise ValueError(
f"Unsupported container type: {self.container_type}. "
"Supported types are 'apptainer' and 'none'."
)
srun_cmds.append(srun_cmd)
srun_cmds = "\n".join(srun_cmds)
sbatch_script = SBATCH_SCRIPT_TEMPLATE.format(
sbatch_options=sbatch_options_str,
srun_additional_args=srun_additional_args,
srun_cmds=srun_cmds,
)
sbatch_file_path = self.sbatch_path_of(f"{job_name}")
with open(sbatch_file_path, "w") as f:
f.write(sbatch_script)
# Submit the job
try:
output = (
subprocess.check_output(["sbatch", sbatch_file_path])
.decode("utf-8")
.strip()
)
logger.info(
f"Submitted Slurm job {self.slurm_name(job_name)} to scheduler. To check the output, run \n\t`tail -f {self.log_path_of(job_name)}`."
)
except subprocess.CalledProcessError as e:
logger.warning(
f"Failed to submit job {self.slurm_name(job_name)}. "
f"For debugging, please make sure your sbatch command works "
f"and check generated sbatch file on {sbatch_file_path}."
)
logger.error(f"Error message: {e}")
return
match = re.search(r"Submitted batch job (\d+)", output)
slurm_job_id = int(match.group(1)) if match else None
if slurm_job_id is None:
logger.warning(
f"Failed to obtain job id for job {self.slurm_name(job_name)}. "
f"sbatch output: {output}"
)
return
assert isinstance(slurm_job_id, int)
self.jobs[slurm_job_id] = JobInfo(
name=self.slurm_name(job_name),
state=JobState.PENDING,
slurm_id=slurm_job_id,
)
self._update_all()
def stop(self, job_name, force=False):
"""Stops a running job.
Passes silently if the job cannot be found or has already stopped,
whether successfully or not.
Args:
job_name: The job name of the job array to stop.
The actual slurm job name will be `<experiment_name>_<trial_name>:<job_name>`.
"""
signal = "SIGKILL" if force else "SIGTERM"
job_id = self.find_job_id(job_name)
if not job_id:
return
return cancel_jobs(slurm_ids=[job_id], signal=signal)
def stop_all(self, force=False):
"""Stops all running jobs."""
signal = "SIGKILL" if force else "SIGTERM"
return cancel_jobs(slurm_ids=list(self.jobs.keys()), signal=signal)
def find(self, job_name) -> JobInfo | None:
"""Gets the status of a job of this job.
Args:
job_name: The job name of the job array to find.
The actual slurm job name will be `<experiment_name>_<trial_name>:<job_name>`.
Returns:
A JobInfo if the job is found, or None otherwise.
"""
self._update_all()
job_id = self.find_job_id(job_name)
return self.jobs[job_id] if job_id else None
def find_all(self, job_name_regex=".*") -> List[JobInfo]:
"""Finds jobs.
Args:
job_name_regex: job name regex.
Returns:
A list of found JobInfo.
"""
self._update_all()
infos = []
for r in self.jobs.values():
job_name = r.name.split(":")[-1] # Extract the job name from slurm name
if re.fullmatch(job_name_regex, job_name):
infos.append(r)
return infos
def _find_job_with_status(
self,
status: List[JobState],
) -> List[JobInfo]:
"""Finds jobs with the given status.
Args:
status: A list of JobState to filter jobs.
Returns:
A list of JobInfo with the given status.
"""
self._update_all()
return [r for r in self.jobs.values() if r.state in status]
def wait(
self,
timeout=None,
check_status: Tuple[JobState, ...] = (
JobState.CANCELLED,
JobState.FAILED,
JobState.NOT_FOUND,
),
remove_status: Tuple[JobState, ...] = (JobState.COMPLETED,),
update=False,
):
"""Waits until all jobs submitted via this client instance finish."""
# begin wait
deadline = None if timeout is None else time.time() + timeout
num_jobs_left = len(self.jobs)
left = list(self.jobs.keys())
logger.info(
f"Waiting for {num_jobs_left} jobs. Jobs IDs: "
f"{','.join(sorted([str(x.slurm_id) for x in self.jobs.values()]))}."
)
while len(left) > 0:
if len(left) < num_jobs_left:
num_jobs_left = len(left)
logger.info(f"Waiting for {num_jobs_left} jobs.")
if deadline is not None and time.time() > deadline:
raise TimeoutError(
f"Timeout waiting for {num_jobs_left} jobs. Job ID: "
f"{','.join(sorted([str(x.slurm_id) for x in self.jobs.values()]))}."
)
self._update_all()
left = list(self.jobs.keys())
for slurm_id in list(left):
slurm_info = self.jobs[slurm_id]
if slurm_info.slurm_id is None:
continue
if slurm_info.state in check_status:
raise JobException(
run_name=self.run_name,
worker_type=slurm_info.name,
host=slurm_info.host,
reason=slurm_info.state,
)
if slurm_info.state in remove_status:
logger.info(
f"Job {slurm_info.name} is {slurm_info.state}. (Removed)"
)
left.remove(slurm_id)
if update:
self.jobs.pop(slurm_info.slurm_id)
time.sleep(SLURM_WAIT_CHECK_TIME_INTERVAL)
def _update_all(self):
"""Updates the status of all jobs."""
try:
slurm_infos = query_jobs(slurm_ids=list(self.jobs.keys()))
for slurm_info in slurm_infos:
assert slurm_info.slurm_id is not None
self.jobs[slurm_info.slurm_id] = slurm_info
except subprocess.CalledProcessError:
logger.warning(
"Calling squeue failed. Check slurm manually if you continue to see this warning."
)
def slurm_main():
config, _ = parse_cli_args(sys.argv[2:])
config.launcher = to_structured_cfg(config.launcher, LauncherConfig)
config.cluster = to_structured_cfg(config.cluster, ClusterSpecConfig)
config.sglang = to_structured_cfg(config.sglang, SGLangConfig)
validate_config_for_distributed_launcher(config)
name_resolve.reconfigure(config.cluster.name_resolve)
name_resolve.clear_subtree(
names.trial_root(
experiment_name=config.experiment_name, trial_name=config.trial_name
)
)
n_nodes = config.cluster.n_nodes
n_gpus_per_node = config.cluster.n_gpus_per_node
launcher = SlurmLauncher(
experiment_name=config.experiment_name,
trial_name=config.trial_name,
fileroot=config.cluster.fileroot,
container_type=config.launcher.slurm.container_type,
)
allocation_mode = config.allocation_mode
allocation_mode = AllocationMode.from_str(allocation_mode)
sglang_cmds = []
sglang_addrs = []
n_sglang_nodes = 0
if allocation_mode.type_ == AllocationType.DECOUPLED_SGLANG:
# Launcher should launch SGLang servers according to allocation mode.
sglang_tp_size = allocation_mode.gen_tp_size
n_sglang_servers = allocation_mode.gen_dp_size
n_sglang_nodes = allocation_mode.gen_world_size // n_gpus_per_node
base_seed = config.sglang.random_seed
sglang_server_cmd_template = f"python3 -m arealite.launcher.sglang_server {' '.join(sys.argv[2:])} sglang.random_seed={{seed}}"
for i in range(n_sglang_servers):
sglang_cmd = sglang_server_cmd_template.format(
seed=base_seed + i,
)
sglang_cmds.append(sglang_cmd)
launcher.submit_array(
job_name="llm_server",
cmd=sglang_cmds,
count=n_sglang_servers,
nodes=n_sglang_nodes,
n_gpus_per_node=config.cluster.n_gpus_per_node,
cpus_per_task=config.launcher.inference_server_cpus_per_gpu
* sglang_tp_size,
mem_per_task=config.launcher.inference_server_mem_per_gpu * sglang_tp_size,
srun_additional_args=config.launcher.slurm.srun_additional_args,
container_image=config.launcher.slurm.inference_server_image,
container_mounts=config.launcher.slurm.mount,
env_vars=get_env_vars(
config.cluster.cluster_name,
config.launcher.inference_server_env_vars,
),
)
# Get SGLang server addresses by name resolve
try:
sglang_addrs = wait_sglang_server_addrs(
config.experiment_name,
config.trial_name,
n_sglang_servers,
)
except TimeoutError as e:
launcher.stop_all(force=True)
raise e
trainer_n_nodes = n_nodes - n_sglang_nodes
# Here $head_node_ip is the IP address of the first node in the job array.
# $trainer_port is a free port on the head node.
# Both of them are obtained by the SBATCH script.
trainer_cmd_template = (
f"torchrun --nnodes={{nnodes}} --nproc-per-node={{nproc_per_node}} --node-rank {{node_rank}} "
f"--master-addr $head_node_ip --master-port $trainer_port {' '.join(sys.argv[1:])}"
)
trainer_cmds = []
for i in range(trainer_n_nodes):
# In slurm, we launch trainer in the granularity of nodes with torchrun command.
trainer_cmds.append(
trainer_cmd_template.format(
nnodes=trainer_n_nodes,
nproc_per_node=config.cluster.n_gpus_per_node,
node_rank=i,
)
)
if not config.server_only:
# launch trainers
launcher.submit_array(
job_name="trainer",
cmd=trainer_cmds,
count=trainer_n_nodes,
nodes=trainer_n_nodes,
n_gpus_per_node=config.cluster.n_gpus_per_node,
cpus_per_task=config.launcher.trainer_cpus_per_gpu
* config.cluster.n_gpus_per_node,
mem_per_task=config.launcher.trainer_mem_per_gpu
* config.cluster.n_gpus_per_node,
container_image=config.launcher.slurm.trainer_image,
srun_additional_args=config.launcher.slurm.srun_additional_args,
container_mounts=config.launcher.slurm.mount,
env_vars=dict(
**get_env_vars(
config.cluster.cluster_name,
config.launcher.trainer_env_vars,
),
AREAL_LLM_SERVER_ADDRS=",".join(sglang_addrs),
),
)
try:
launcher.wait(
check_status=(
JobState.CANCELLED,
JobState.FAILED,
JobState.NOT_FOUND,
JobState.COMPLETED,
),
remove_status=(),
)
except (KeyboardInterrupt, JobException, TimeoutError) as e:
launcher.stop_all(force=True)
raise e
if __name__ == "__main__":
# usage: python -m arealite.launcher.slurm <entry_point> \
# --config <config_path> [<additional_args>]
slurm_main()
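Because Slurm jobs here receive whole nodes, submit_array resolves CUDA_VISIBLE_DEVICES per task itself. A worked trace with assumed numbers, following the arithmetic above (16 single-GPU tasks over 2 nodes with 8 GPUs each):

count, nodes, n_gpus_per_node = 16, 2, 8                      # illustrative values
ntasks_per_node = count // nodes                              # 8
n_gpus_per_task = n_gpus_per_node // ntasks_per_node          # 1
i = 10                                                        # the 11th task
node_id = i // ntasks_per_node                                # 1 -> second node
gpu_id_start = (i % ntasks_per_node) * n_gpus_per_task        # 2
gpu_id_end = ((i % ntasks_per_node) + 1) * n_gpus_per_task    # 3
cuda_visible = ",".join(str(x) for x in range(gpu_id_start, gpu_id_end))  # "2"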

View File

@ -2,11 +2,9 @@ import os
import subprocess
import sys
import time
import uuid
import pytest
import requests
import torch
from arealite.api.cli_args import (
InferenceEngineConfig,
@ -18,7 +16,6 @@ from arealite.api.io_struct import FinetuneSpec, WeightUpdateMeta
from arealite.engine.fsdp_engine import FSDPEngine
from arealite.engine.sglang_remote import RemoteSGLangEngine
from arealite.utils.network import find_free_ports
from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import network
EXPR_NAME = "test_fsdp_engine_nccl"

View File

@ -399,9 +399,15 @@ def pad_packed_tensor_dict(
padded_data[key] = new_max_seqlen
elif torch.is_tensor(value) and value.numel() == total_length:
# Pad the tensor to the new total length
padded_tensor = torch.nn.functional.pad(
value, (0, pad_length), value=pad_value
)
if key == "position_ids":
# transformers will compute flash-attn arguments (e.g., cu_seqlens_q)
# according to these position ids.
pad = torch.arange(pad_length, dtype=torch.long, device=value.device)
padded_tensor = torch.cat([value, pad])
else:
padded_tensor = torch.nn.functional.pad(
value, (0, pad_length), value=pad_value
)
padded_data[key] = padded_tensor
else:
padded_data[key] = value
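A toy illustration of why position_ids are padded with arange instead of a constant: transformers derives the flash-attention cu_seqlens from runs of increasing position ids, so the pad region must look like one extra (fully masked) sequence rather than a stretch of repeated values. The tensors below are made up:

import torch

value = torch.tensor([0, 1, 2, 0, 1])   # position_ids of two packed sequences
pad_length = 3
constant_pad = torch.nn.functional.pad(value, (0, pad_length), value=0)
arange_pad = torch.cat([value, torch.arange(pad_length, dtype=torch.long)])
# constant_pad -> tensor([0, 1, 2, 0, 1, 0, 0, 0])  (looks like three length-1 sequences)
# arange_pad   -> tensor([0, 1, 2, 0, 1, 0, 1, 2])  (one extra well-formed sequence)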

arealite/utils/launcher.py (new file, 102 lines added)
View File

@ -0,0 +1,102 @@
import getpass
import os
import pathlib
import time
from typing import Dict, Optional
from arealite.api.io_struct import AllocationMode, AllocationType
from realhf.base import logging, name_resolve, names
logger = logging.getLogger("Launcher Utils")
LOCAL_CACHE_DIR = "/tmp/arealite"
PYTORCH_KERNEL_CACHE_PATH = (
f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/torch/kernels/"
)
TRITON_CACHE_PATH = f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/triton/"
os.makedirs(PYTORCH_KERNEL_CACHE_PATH, exist_ok=True)
os.makedirs(TRITON_CACHE_PATH, exist_ok=True)
BASE_ENVIRONS = {
"TOKENIZERS_PARALLELISM": "true",
"PYTORCH_KERNEL_CACHE_PATH": PYTORCH_KERNEL_CACHE_PATH,
"TRITON_CACHE_DIR": TRITON_CACHE_PATH,
"CUDA_DEVICE_MAX_CONNECTIONS": "1",
"PYTHONPATH": str(pathlib.Path(__file__).resolve().parent.parent.parent),
}
NA132_ENVIRONS = {
"NCCL_SOCKET_IFNAME": "bond0",
"NCCL_NET_PLUGIN": "",
"NCCL_IB_GID_INDEX": "3",
"NCCL_IB_TIMEOUT": "2",
"NCCL_IB_RETRY_CNT": "7",
"NCCL_IB_SL": "5",
"NCCL_IB_TC": "136",
"NCCL_IB_HCA": "mlx5_bond",
"NCCL_IB_QPS_PER_CONNECTION": "8",
"NCCL_SET_THREAD_NAME": "1",
"NCCL_DEBUG": "WARN",
"NCCL_DEBUG_SUBSYS": "INIT,TUNING,GRAPH",
}
SGLANG_SERVER_WAIT_TIMEOUT_SECONDS = 180
def get_env_vars(
cluster_name: str, additional_env_vars: Optional[str] = None
) -> Dict[str, str]:
"""Returns the environment variables for the cluster."""
_additional_env_vars = (
dict(item.split("=") for item in additional_env_vars.split(","))
if additional_env_vars
else dict()
)
if cluster_name == "na132":
return {**BASE_ENVIRONS, **NA132_ENVIRONS, **_additional_env_vars}
else:
return {**BASE_ENVIRONS, **_additional_env_vars}
def wait_sglang_server_addrs(
experiment_name: str,
trial_name: str,
n_sglang_servers: int,
):
# Get SGLang slurm nodes, find the hosts
name = names.gen_servers(experiment_name, trial_name)
start = time.perf_counter()
while True:
sglang_addrs = name_resolve.get_subtree(name)
if len(sglang_addrs) >= n_sglang_servers:
logger.info(
f"Found {len(sglang_addrs)} SGLang servers: {', '.join(sglang_addrs)}"
)
break
time.sleep(1)
if time.perf_counter() - start > SGLANG_SERVER_WAIT_TIMEOUT_SECONDS:
raise TimeoutError(
f"Timeout waiting for SGLang servers to be ready. "
f"Expected {n_sglang_servers} servers, found {len(sglang_addrs)}."
)
return sglang_addrs
def validate_config_for_distributed_launcher(config):
n_nodes = config.cluster.n_nodes
n_gpus_per_node = config.cluster.n_gpus_per_node
allocation_mode = config.allocation_mode
allocation_mode = AllocationMode.from_str(allocation_mode)
assert (
allocation_mode.gen_world_size + allocation_mode.train_world_size
== n_nodes * n_gpus_per_node
), (
f"#GPUs required for allocation mode {allocation_mode.gen_world_size + allocation_mode.train_world_size} "
f"is not equal to #GPUs in the config {n_nodes * n_gpus_per_node}."
)
if allocation_mode.type_ == AllocationType.DECOUPLED_SGLANG:
# Launcher should launch SGLang servers according to allocation mode.
assert (
allocation_mode.gen_pp_size == 1
), "Pipeline generation in SGLang is not supported for now."
assert (
allocation_mode.gen_tp_size <= config.cluster.n_gpus_per_node
), "Currently only support SGLang TP size less <= #GPUs per node."

View File

@ -1,48 +0,0 @@
from typing import List
import torch
from tensordict import TensorDict
def concat_padded_tensors(
tensor_dicts: List[TensorDict], pad_value: float = 0.0
) -> TensorDict:
"""Concatenate and pad tensors from multiple padded tensor dictionaries."""
if not tensor_dicts:
return TensorDict()
batch_sizes = [tuple(d.batch_size) for d in tensor_dicts]
new_batch_size = [sum(x[0] for x in batch_sizes), *batch_sizes[0][1:]]
# Find max sequence length across all dictionaries
assert all("attention_mask" in td for td in tensor_dicts)
max_length = max([x["attention_mask"].shape[1] for x in tensor_dicts])
result = {}
# Process each key
for key in tensor_dicts[0].keys():
tensors_to_concat = []
for tensor_dict in tensor_dicts:
tensor = tensor_dict[key]
# Skip 1D tensors like rewards
if len(tensor.shape) == 1:
tensors_to_concat.append(tensor)
continue
current_length = tensor.shape[1]
if current_length < max_length:
# Pad tensor to max_length
pad_width = max_length - current_length
if key == "attention_mask":
# Pad attention mask with 0s
padding = torch.zeros(
(tensor.shape[0], pad_width), dtype=tensor.dtype
)
else:
# Pad feature tensors with pad_value
padding = torch.full(
(tensor.shape[0], pad_width), pad_value, dtype=tensor.dtype
)
tensor = torch.cat([tensor, padding], dim=1)
tensors_to_concat.append(tensor)
result[key] = torch.cat(tensors_to_concat, dim=0)
return TensorDict(result, batch_size=new_batch_size)

arealite/utils/ray.py (new file, 23 lines added)
View File

@ -0,0 +1,23 @@
import ray
from ray.util.placement_group import PlacementGroup
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from arealite.utils.network import find_free_ports, gethostip
def get_placement_group_master_ip_and_port(placement_group: PlacementGroup):
def _master_ip_and_port():
host_ip = gethostip()
port = find_free_ports(1, (10000, 60000))[0]
return host_ip, port
future = ray.remote(
num_cpus=1,
num_gpus=0,
memory=10 * 1024 * 1024, # Convert MB to bytes
scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=placement_group,
placement_group_bundle_index=0,
),
)(_master_ip_and_port).remote()
return ray.get(future)

arealite/utils/slurm.py (new file, 154 lines added)
View File

@ -0,0 +1,154 @@
import subprocess
from typing import List, Literal, Optional
from realhf.base import logging
from realhf.scheduler.client import JobInfo, JobState
logger = logging.getLogger("Slurm Utils")
SQUEUE_FIELDS = [
"JobID",
"State",
"SubmitTime",
"StartTime",
"Name",
"NodeList",
"UserName",
"MaxCPUs",
"cpus-per-task",
"NumTasks",
"tres-alloc",
]
STATUS_MAPPING = {
"RUNNING": JobState.RUNNING,
"COMPLETING": JobState.RUNNING,
"PENDING": JobState.PENDING,
"CANCELLED": JobState.CANCELLED,
"FAILED": JobState.FAILED,
"COMPLETED": JobState.COMPLETED,
"OUT_OF_MEMORY": JobState.FAILED,
"DEADLINE": JobState.COMPLETED,
"TIMEOUT": JobState.COMPLETED,
}
SBATCH_SCRIPT_TEMPLATE = """#!/bin/bash
{sbatch_options}
# Getting the node names
nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
echo nodes=$nodes
nodes_array=($nodes)
echo node_array=$nodes_array
head_node=${{nodes_array[0]}}
echo head_node=$head_node
# Getting the head node IP address
head_node_ip=$(srun {srun_additional_args} --nodes=1 --ntasks=1 -n1 -c1 --mem=10M --nodelist="$head_node" hostname --ip-address)
echo head_node_ip=$head_node_ip
# Find a free port on the head node
# Wonderful linux command to find a random free port (between 10000 and 60000) by deepseek
trainer_port=$(srun {srun_additional_args} --nodes=1 --ntasks=1 -n1 -c1 --mem=10M --nodelist="$head_node" bash -c "comm -23 <(seq 10000 60000 | sort) <(ss -tan | awk '{{print $4}}' | cut -d':' -f2 | grep '[0-9]\\{{1,5\\}}' | sort -u) | shuf | head -n 1")
echo trainer_port=$trainer_port
# srun commands
{srun_cmds}
wait
"""
SRUN_CMD_TEMPLATE: str = """srun {additional_args} \\
--nodelist=${{nodes_array[{node_id}]}} --nodes={nodes} --ntasks={ntasks} \\
--gres=gpu:{n_gpus_per_node} --cpus-per-task={cpus_per_task} --mem-per-cpu={mem_per_cpu}M \\
{cmd} &
"""
APPTAINER_CMD_TEMPLATE: str = """singularity exec --no-home --writable-tmpfs --nv --pid \\
--bind {container_mounts} \\
{container_env_strings} \\
{container_image} \\
{cmd}"""
def cancel_jobs(
slurm_names: Optional[List[str]] = None,
slurm_ids: Optional[List[int]] = None,
signal: Literal["SIGINT", "SIGKILL"] = "SIGKILL",
):
assert (
slurm_names is not None or slurm_ids is not None
), "Must specify slurm_names or slurm_ids."
assert not (
slurm_names and slurm_ids
), "Cannot specify both slurm_names and slurm_ids."
cmd = ["scancel", "-s", signal]
if slurm_names is not None:
cmd += ["-n", ",".join(slurm_names)]
elif slurm_ids is not None:
cmd += [",".join(str(s) for s in slurm_ids)]
subprocess.check_call(cmd)
logger.info(
f"Cancelled Slurm job with signal {signal}: "
f"slurm identifiers {slurm_names if slurm_ids is None else slurm_ids}. CMD: {cmd}"
)
def query_jobs(
slurm_names: Optional[List[str]] = None,
slurm_ids: Optional[List[int]] = None,
status: str = "all",
delimiter: str = "__PSI__",
) -> List[JobInfo]:
squeue_format = f":.{delimiter},".join(SQUEUE_FIELDS)
cmd = ["squeue", "-O", squeue_format, f"-t{status}"]
if slurm_names is not None:
cmd += ["-n", ",".join(slurm_names)]
if slurm_ids is not None:
cmd += ["-j", ",".join([str(s) for s in slurm_ids])]
output = (
subprocess.check_output(cmd, stderr=subprocess.DEVNULL).decode("ascii").strip()
)
rs = []
for line in output.split("\n")[1:]:
job_id, state, submit_time, start_time, slurm_name, nodelist, *_ = line.split(
delimiter
)
rs.append(
JobInfo(
name=slurm_name,
state=STATUS_MAPPING[state],
host=nodelist,
submit_time=submit_time,
start_time=start_time,
slurm_id=int(job_id.strip()),
)
)
return rs
def parse_slurm_nodelist(nodelist: str) -> List[str]:
return (
subprocess.check_output(
[
"scontrol",
"show",
"hostnames",
nodelist,
]
)
.decode("utf-8")
.strip()
.split("\n")
)
def get_slurm_host_ip(node: str, srun_addtional_args: str):
try:
cmd = f"srun {srun_addtional_args} --immediate=1 --nodes=1 --ntasks=1 -n1 -c1 --mem=10M --nodelist={node} hostname --ip-address"
return subprocess.check_output(cmd.split(" ")).decode("utf-8").strip()
except subprocess.CalledProcessError:
logger.warning(f"Get slurm host IP for node {node} failed.")

View File

@ -73,7 +73,7 @@ class StatsLogger:
)
if isinstance(data, Dict):
data = [data]
log_step = max(global_step, self._last_commit_step)
log_step = max(global_step, self._last_commit_step + 1)
for i, item in enumerate(data):
self.info(f"Stats ({i+1}/{len(data)}):")
self.print_stats(item)

View File

@ -1,11 +1,14 @@
import asyncio
import os
import uuid
import colorama
import torch
from tensordict import TensorDict
from transformers import PreTrainedTokenizerFast
from arealite.api.cli_args import GenerationHyperparameters
from arealite.api.engine_api import InferenceEngine
from arealite.api.io_struct import LLMRequest
from arealite.api.workflow_api import RolloutWorkflow
from arealite.utils.data import concat_padded_tensors
@ -18,13 +21,17 @@ class RLVRWorkflow(RolloutWorkflow):
gconfig: GenerationHyperparameters,
tokenizer: PreTrainedTokenizerFast,
enable_thinking: bool,
dump_dir: str | None = None,
):
self.reward_fn = reward_fn
self.gconfig = gconfig
self.tokenizer = tokenizer
self.enable_thinking = enable_thinking
self.dump_dir = dump_dir
if self.dump_dir is not None and not os.path.exists(self.dump_dir):
os.makedirs(self.dump_dir, exist_ok=True)
async def arun_episode(self, engine, data):
async def arun_episode(self, engine: InferenceEngine, data):
input_ids = self.tokenizer.apply_chat_template(
data["messages"],
tokenize=True,
@ -39,6 +46,12 @@ class RLVRWorkflow(RolloutWorkflow):
)
resps = await asyncio.gather(*[engine.agenerate(req) for _ in range(n_samples)])
version = engine.get_version()
prompt_strs = []
completions_strs = []
rewards = []
seqlens = []
results = []
for resp in resps:
seq = resp.input_tokens + resp.output_tokens
@ -46,13 +59,19 @@ class RLVRWorkflow(RolloutWorkflow):
loss_mask = [0] * resp.input_len + [1] * resp.output_len
versions = [-1] * resp.input_len + resp.output_versions
prompt_str = self.tokenizer.decode(input_ids)
completions_str = self.tokenizer.decode(resp.output_tokens)
prompt_strs.append(prompt_str)
completions_strs.append(completions_str)
seqlens.append(len(seq))
reward = self.reward_fn(
prompt=self.tokenizer.decode(input_ids),
completions=self.tokenizer.decode(resp.output_tokens),
prompt=prompt_str,
completions=completions_str,
prompt_ids=resp.input_tokens,
completion_ids=resp.output_tokens,
**data,
)
rewards.append(reward)
res = dict(
# unsqueeze to add an additional batch dimension
input_ids=torch.tensor(seq).unsqueeze(0),
@ -65,4 +84,31 @@ class RLVRWorkflow(RolloutWorkflow):
)
results.append(TensorDict(res, batch_size=[1]))
if self.dump_dir is not None:
os.makedirs(os.path.join(self.dump_dir, str(version)), exist_ok=True)
# Get the unique identifier for this prompt
qid = None
for key in ["query_id", "id", "qid"]:
qid = data.get(key, None)
if qid is not None:
break
qid = qid or uuid.uuid4().hex
# Dump rollout to file
with open(
os.path.join(self.dump_dir, str(version), f"{qid}.txt"), "a"
) as f:
n_samples = self.gconfig.n_samples
for i, (p, c, r, sl) in enumerate(
zip(prompt_strs, completions_strs, rewards, seqlens)
):
info = "\n".join(
[
f"idx: {i + 1} / {n_samples}, seqlen: {sl}, reward is {r}.",
f"prompt is \n{colorama.Fore.YELLOW + colorama.Style.DIM}{p}{colorama.Style.RESET_ALL}",
f"sequence is: \n{colorama.Fore.YELLOW + colorama.Style.DIM}{c}{colorama.Style.RESET_ALL}",
]
)
f.write(info + "\n")
return concat_padded_tensors(results)
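When dump_dir is set, each rollout is appended to a per-prompt text file grouped by the model version at generation time. A small sketch of the path construction used above; directory names and ids are illustrative:

import os, uuid

dump_dir, version = "/tmp/rollout_dumps", 3   # illustrative values
qid = uuid.uuid4().hex                        # fallback when data has no query_id/id/qid
path = os.path.join(dump_dir, str(version), f"{qid}.txt")
# -> /tmp/rollout_dumps/3/<qid>.txt; each appended block records
#    "idx: i / n_samples, seqlen, reward" followed by the prompt and completion strings.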

examples/arealite/boba.py (new file, 310 lines added)
View File

@ -0,0 +1,310 @@
import asyncio
import os
import shutil
import sys
import uuid
import colorama
import torch
import torch.distributed as dist
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node
from tensordict import TensorDict
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import PreTrainedTokenizerFast
from arealite.api.cli_args import (
GenerationHyperparameters,
GRPOConfig,
load_expr_config,
)
from arealite.api.io_struct import FinetuneSpec, LLMRequest, WeightUpdateMeta
from arealite.api.workflow_api import RolloutWorkflow
from arealite.engine.ppo.actor import FSDPPPOActor
from arealite.engine.sglang_remote import RemoteSGLangEngine
from arealite.utils.data import concat_padded_tensors
from arealite.utils.device import log_gpu_stats
from arealite.utils.saver import Saver
from arealite.utils.stats_logger import StatsLogger
from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import logging, seeding, stats_tracker
logger = logging.getLogger("boba math")
class RLVRWorkflow(RolloutWorkflow):
def __init__(
self,
reward_fn,
gconfig: GenerationHyperparameters,
tokenizer: PreTrainedTokenizerFast,
dump_dir: str | None = None,
):
self.reward_fn = reward_fn
self.gconfig = gconfig
self.tokenizer = tokenizer
self.dump_dir = dump_dir
if self.dump_dir is not None and not os.path.exists(self.dump_dir):
os.makedirs(self.dump_dir, exist_ok=True)
async def arun_episode(self, engine, data):
input_ids = self.tokenizer.encode(data["prompt"])
n_samples = self.gconfig.n_samples
req = LLMRequest(
rid=uuid.uuid4().hex,
input_ids=input_ids,
gconfig=self.gconfig.new(n_samples=1),
)
resps = await asyncio.gather(*[engine.agenerate(req) for _ in range(n_samples)])
version = engine.get_version()
prompt_strs = []
completions_strs = []
rewards = []
seqlens = []
results = []
for resp in resps:
seq = resp.input_tokens + resp.output_tokens
logprobs = [0.0] * resp.input_len + resp.output_logprobs
loss_mask = [0] * resp.input_len + [1] * resp.output_len
versions = [-1] * resp.input_len + resp.output_versions
prompt_str = data["prompt"]
completions_str = self.tokenizer.decode(resp.output_tokens)
prompt_strs.append(prompt_str)
completions_strs.append(completions_str)
seqlens.append(len(seq))
reward = self.reward_fn(
completions=completions_str,
prompt_ids=resp.input_tokens,
completion_ids=resp.output_tokens,
**data,
)
rewards.append(reward)
res = dict(
# unsqueeze to add an additional batch dimension
input_ids=torch.tensor(seq).unsqueeze(0),
loss_mask=torch.tensor(loss_mask).unsqueeze(0),
logprobs=torch.tensor(logprobs).unsqueeze(0),
versions=torch.tensor(versions).unsqueeze(0),
attention_mask=torch.ones(len(seq), dtype=torch.bool).unsqueeze(0),
# reward
rewards=torch.tensor([float(reward)]),
)
results.append(TensorDict(res, batch_size=[1]))
if self.dump_dir is not None:
os.makedirs(os.path.join(self.dump_dir, str(version)), exist_ok=True)
# Get the unique identifier for this prompt
qid = None
for key in ["query_id", "id", "qid"]:
qid = data.get(key, None)
if qid is not None:
break
qid = qid or uuid.uuid4().hex
# Dump rollout to file
with open(
os.path.join(self.dump_dir, str(version), f"{qid}.txt"), "a"
) as f:
n_samples = self.gconfig.n_samples
for i, (p, c, r, sl) in enumerate(
zip(prompt_strs, completions_strs, rewards, seqlens)
):
info = "\n".join(
[
f"idx: {i + 1} / {n_samples}, seqlen: {sl}, reward is {r}.",
f"prompt is \n{colorama.Fore.YELLOW + colorama.Style.DIM}{p}{colorama.Style.RESET_ALL}",
f"sequence is: \n{colorama.Fore.YELLOW + colorama.Style.DIM}{c}{colorama.Style.RESET_ALL}",
]
)
f.write(info + "\n")
return concat_padded_tensors(results)
def get_boba_math_dataset(tokenizer, rank, world_size):
dataset = load_dataset(
path="json",
split="train",
data_files="/storage/openpsi/users/xushusheng.xss/training_data/boba_106k_0319.jsonl",
)
dataset = dataset.filter(lambda x: len(tokenizer.encode(x["prompt"])) <= 1024)
return split_dataset_by_node(dataset, rank=rank, world_size=world_size)
def boba_reward_fn(
prompt, completions, prompt_ids, completion_ids, query_id, solutions, **kwargs
):
from pebble import ProcessExpired, ProcessPool
from realhf.impl.dataset.math_parser import process_results
jobs = []
with ProcessPool(max_workers=1) as executor:
for sol in solutions:
job = executor.schedule(
process_results, args=[completions, sol], timeout=15
)
jobs.append(job)
label = 0
for job in jobs:
try:
x = job.result()
except TimeoutError:
# print("[debug: timeout]")
logger.warning(f"Timeout occurred while justifying the math answer.")
x = (0, "timeout", "timeout")
except ProcessExpired as e:
logger.warning(f"Process terminated abnormally: {e}")
x = (0, "error", "error")
except Exception as e:
logger.warning(f"Other error occurred: {e.__class__.__name__}, {e}")
x = (0, "error", "error")
label = label or x[0]
return label
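boba_reward_fn wraps the math verifier in a pebble ProcessPool so that one pathological completion cannot stall scoring for more than 15 seconds per reference solution. The same timeout pattern in isolation, with a hypothetical slow_check standing in for process_results:

from concurrent.futures import TimeoutError
from pebble import ProcessExpired, ProcessPool

def slow_check(completion: str, solution: str) -> tuple:
    # Hypothetical stand-in for realhf.impl.dataset.math_parser.process_results.
    return (int(solution in completion), completion, solution)

def verify_with_timeout(completion: str, solutions: list[str], timeout: float = 15.0) -> int:
    label = 0
    with ProcessPool(max_workers=1) as pool:
        jobs = [
            pool.schedule(slow_check, args=[completion, sol], timeout=timeout)
            for sol in solutions
        ]
        for job in jobs:
            try:
                label = label or job.result()[0]
            except TimeoutError:
                pass  # verifier exceeded the per-job timeout; count as incorrect
            except ProcessExpired:
                pass  # worker process died; count as incorrect
    return label

# verify_with_timeout("48 + 24 = 72, so the answer is 72.", ["72"]) evaluates to 1.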
def main(args):
config, _ = load_expr_config(args, GRPOConfig)
config: GRPOConfig
rank = int(os.getenv("RANK"))
world_size = int(os.getenv("WORLD_SIZE"))
tokenizer = load_hf_tokenizer(config.tokenizer_path)
seeding.set_random_seed(config.seed, key=f"trainer{rank}")
# Create dataset and dataloaders
train_dataloader = StatefulDataLoader(
get_boba_math_dataset(tokenizer, rank, world_size),
batch_size=config.train_dataset.batch_size // world_size,
shuffle=config.train_dataset.shuffle,
num_workers=config.train_dataset.num_workers,
collate_fn=lambda x: x,
drop_last=config.train_dataset.drop_last,
)
ft_spec = FinetuneSpec(
total_train_epochs=config.total_train_epochs,
dataset_size=len(train_dataloader) * config.train_dataset.batch_size,
train_batch_size=config.train_dataset.batch_size,
)
# Initialize inference engine
rollout = RemoteSGLangEngine(config.rollout)
rollout.initialize(None, ft_spec)
# Initialize train engine
actor = FSDPPPOActor(config=config.actor)
actor.initialize(None, ft_spec)
ref = None
if config.actor.kl_ctl > 0 and config.ref is not None:
ref = FSDPPPOActor(config=config.ref)
ref.initialize(None, ft_spec)
# Create rollout workflow
if tokenizer.pad_token_id not in config.gconfig.stop_token_ids:
config.gconfig.stop_token_ids.append(tokenizer.pad_token_id)
if tokenizer.eos_token_id not in config.gconfig.stop_token_ids:
config.gconfig.stop_token_ids.append(tokenizer.eos_token_id)
workflow = RLVRWorkflow(
reward_fn=boba_reward_fn,
gconfig=config.gconfig,
tokenizer=tokenizer,
dump_dir=os.path.join(
StatsLogger.get_log_path(config.stats_logger), "generated"
),
)
# Run training.
saver = Saver(config.saver, ft_spec, for_recover=False)
logger = StatsLogger(config.stats_logger, ft_spec)
total_epochs = config.total_train_epochs
steps_per_epoch = len(train_dataloader)
max_steps = total_epochs * steps_per_epoch
logger.info(f"total_epochs={total_epochs} step_per_epoch={steps_per_epoch}")
data_generator = iter(train_dataloader)
for global_step in range(max_steps):
epoch = global_step // steps_per_epoch
step = global_step % steps_per_epoch
with stats_tracker.record_timing("rollout"):
if config.async_training:
batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
else:
try:
data = next(data_generator)
except StopIteration:
data_generator = iter(train_dataloader)
data = next(data_generator)
batch = rollout.rollout_batch(data, workflow=workflow)
batch = batch.to(actor.device)
# Barrier to keep all trainer processes in step after rollout.
dist.barrier()
torch.cuda.synchronize()
if config.actor.recompute_logprob or config.actor.use_decoupled_loss:
with stats_tracker.record_timing("recompute_logp"):
logp = actor.compute_logp(batch)
batch["prox_logp"] = logp
log_gpu_stats("recompute logp")
if ref is not None:
with stats_tracker.record_timing("ref_logp"):
batch["ref_logp"] = ref.compute_logp(batch)
log_gpu_stats("ref logp")
with stats_tracker.record_timing("compute_advantage"):
actor.compute_advantages(batch)
log_gpu_stats("compute advantages")
with (
stats_tracker.record_timing("train_step"),
stats_tracker.scope("grpo_actor"),
):
stats = actor.ppo_update(batch)
actor.step_lr_scheduler()
log_gpu_stats("ppo update")
with stats_tracker.record_timing("update_weights"):
path = os.path.join(
Saver.get_save_checkpoint_root(config.saver),
"update_weights",
str(global_step + 1),
)
meta = WeightUpdateMeta(
type="disk",
path=path,
alloc_mode=None,
comm_backend=None,
model_version=global_step + 1,
)
if dist.get_rank() == 0:
future = rollout.update_weights(meta)
actor.upload_weights(meta)
if dist.get_rank() == 0:
future.result()
shutil.rmtree(path, ignore_errors=True)
dist.barrier()
torch.cuda.synchronize()
rollout.set_version(global_step + 1)
with stats_tracker.record_timing("save"):
saver.save(actor, epoch, step, global_step)
logger.commit(epoch, step, global_step, stats)
logger.close()
rollout.destroy()
if ref is not None:
ref.destroy()
actor.destroy()
if __name__ == "__main__":
main(sys.argv[1:])
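The update_weights block in the training loop above is a rank-0-coordinated, disk-based handshake: rank 0 asks the inference servers to reload from a fresh checkpoint directory, every trainer rank writes its shard, rank 0 waits for the reload and removes the directory, and a barrier keeps the ranks in step before the rollout version is bumped. Condensed into a sketch (assuming WeightUpdateMeta exposes its path field; the calls otherwise mirror the loop above):

import shutil
import torch
import torch.distributed as dist

def disk_weight_sync(actor, rollout, meta, new_version: int):
    # Rank 0 kicks off the asynchronous reload first so the inference servers
    # start watching meta.path while the trainer ranks write their shards.
    future = rollout.update_weights(meta) if dist.get_rank() == 0 else None
    actor.upload_weights(meta)        # every rank writes its shard under meta.path
    if future is not None:
        future.result()               # rank 0 blocks until the servers finish loading
        shutil.rmtree(meta.path, ignore_errors=True)
    dist.barrier()                    # keep all trainer ranks in step
    torch.cuda.synchronize()
    rollout.set_version(new_version)  # subsequent generations are tagged with this version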

View File

@@ -0,0 +1,141 @@
experiment_name: lite-boba-math
trial_name: run1
cluster:
n_nodes: 16
n_gpus_per_node: 8
cluster_name: na132
fileroot: /storage/openpsi/experiments
name_resolve:
type: nfs
nfs_record_root: /storage/openpsi/experiments/name_resolve/lite-boba-math
etcd3_addr: etcd-client.openpsi-etcd.svc.sigma-na130-lingbo.na130.wl-robby.local:2379
seed: 1
total_train_epochs: 10
total_train_steps: null
tokenizer_path: ${actor.path}
allocation_mode: sglang.d96p1t1+d32p1t1
async_training: true
rollout:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
max_concurrent_rollouts: 400
queue_size: null
consumer_batch_size: ${train_dataset.batch_size}
max_head_offpolicyness: 4
enable_rollout_tracing: true
gconfig:
n_samples: 16
min_new_tokens: 0
max_new_tokens: 30720
greedy: false
temperature: 1.0
actor:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
path: /storage/openpsi/models/deepseek-ai__DeepSeek-R1-Distill-Qwen-1.5B/
init_from_scratch: false
disable_dropout: true
gradient_checkpointing: true
dtype: bfloat16
mb_spec:
max_tokens_per_mb: 32768
optimizer:
type: adam
lr: 1e-5
weight_decay: 0.01
beta1: 0.9
beta2: 0.999
eps: 1e-8
lr_scheduler_type: constant
gradient_clipping: 1.0
warmup_steps_proportion: 0.001
backend: fsdp
group_size: ${gconfig.n_samples}
group_adv_norm: false
eps_clip: 0.4
temperature: ${gconfig.temperature}
reward_scaling: 10.0
reward_bias: -0.5
kl_ctl: 0.0
ppo_n_minibatches: 4
recompute_logprob: true
use_decoupled_loss: true
behav_imp_weight_cap: 5.0
ref:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
path: ${actor.path}
init_from_scratch: false
disable_dropout: true
dtype: ${actor.dtype}
mb_spec:
max_tokens_per_mb: 32768
optimizer: null
backend: fsdp
# SGLang
server_only: false
sglang:
model_path: ${actor.path}
random_seed: ${seed}
skip_tokenizer_init: true
dtype: ${actor.dtype}
max_running_requests: null
context_length: 32768
mem_fraction_static: 0.9
# datasets
train_dataset:
batch_size: 512
shuffle: true
pin_memory: true
# Utilities
saver:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
fileroot: ${cluster.fileroot}
freq_epochs: 1
freq_steps: null
freq_secs: null
checkpointer:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
fileroot: ${cluster.fileroot}
freq_epochs: 1
freq_steps: null
freq_secs: 3600
evaluator:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
fileroot: ${cluster.fileroot}
freq_epochs: null
freq_steps: null
freq_secs: null
stats_logger:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
fileroot: ${cluster.fileroot}
wandb:
mode: online
# Launcher
launcher:
inference_server_cpus_per_gpu: 15
inference_server_mem_per_gpu: 153600
trainer_cpus_per_gpu: 15
trainer_mem_per_gpu: 153600
slurm:
mount: /storage:/storage
trainer_image: /storage/openpsi/images/arealite-20250712-update-hf-xet.sif
inference_server_image: /storage/openpsi/images/arealite-20250712-update-hf-xet.sif
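One consistency constraint behind the config above: allocation_mode sglang.d96p1t1+d32p1t1 splits the cluster into a pure data-parallel SGLang group and an FSDP trainer group, and the two groups must add up to cluster.n_nodes * cluster.n_gpus_per_node. Reading dX pY tZ as data-, pipeline-, and tensor-parallel degrees is an assumption about the syntax; under that reading:

# Sanity check for the allocation above (assumed d/p/t = data/pipeline/tensor parallel degrees).
n_nodes, n_gpus_per_node = 16, 8
sglang_gpus = 96 * 1 * 1    # sglang.d96p1t1
trainer_gpus = 32 * 1 * 1   # d32p1t1, the FSDP actor ranks
assert sglang_gpus + trainer_gpus == n_nodes * n_gpus_per_node  # 96 + 32 == 128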

View File

@@ -41,7 +41,7 @@ actor:
max_tokens_per_mb: 10240
optimizer:
type: adam
lr: 2e-6
lr: 1e-5
weight_decay: 0.01
beta1: 0.9
beta2: 0.999

View File

@@ -1,5 +1,5 @@
import os
import re
import shutil
import sys
import torch
@@ -18,7 +18,9 @@ from arealite.utils.saver import Saver
from arealite.utils.stats_logger import StatsLogger
from arealite.workflow.rlvr import RLVRWorkflow
from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import stats_tracker
from realhf.base import logging, seeding, stats_tracker
logger = logging.getLogger("GSM8K grpo")
def process_gsm8k_rl_dataset(dataset: Dataset):
@@ -36,54 +38,22 @@ def get_gsm8k_dataset(split, rank, world_size):
return process_gsm8k_rl_dataset(dataset)
# Adapted from verl.
def extract_solution(solution_str, method="strict") -> str | None:
assert method in ["strict", "flexible"]
final_answer = None
if method == "strict":
# this also tests the formatting of the model
solutions = re.findall("#### (\\-?[0-9\\.\\,]+)", solution_str)
if len(solutions) == 0:
final_answer = None
else:
# take the last solution
final_answer = solutions[-1].replace(",", "").replace("$", "")
elif method == "flexible":
answer = re.findall("(\\-?[0-9\\.\\,]+)", solution_str)
final_answer = None
if len(answer) == 0:
# no reward if there is no answer
pass
else:
invalid_str = ["", "."]
# find the last number that is not '.'
for final_answer in reversed(answer):
if final_answer not in invalid_str:
break
return final_answer
def gsm8k_reward_fn(prompt, completions, prompt_ids, completion_ids, answer, **kwargs):
from realhf.impl.dataset.math_parser import extract_answer
from realhf.impl.dataset.math_parser import process_results
sol = extract_answer(completions, data_name="math")
ans = extract_solution(solution_str=answer, method="strict")
if sol is None:
return 0
if ans is None:
return 0
return int(sol.strip() == ans.strip())
return int(process_results(completions, answer)[0])
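With the hand-rolled extract_solution removed, the GSM8K reward delegates verification entirely to process_results, assumed to return a tuple whose first element is 1 for a correct answer and 0 otherwise. A hypothetical call (the strings and the expected answer format are made up for illustration):

score = gsm8k_reward_fn(
    prompt="Natalia sold 48 clips in April and half as many in May. How many in total?",
    completions="She sold 48 + 24 = 72 clips in total. The answer is 72.",
    prompt_ids=[],       # token ids are accepted but unused by this reward
    completion_ids=[],
    answer="72",
)
assert score in (0, 1)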
def main_grpo():
config, _ = load_expr_config(sys.argv[1:], GRPOConfig)
def main(args):
config, _ = load_expr_config(args, GRPOConfig)
config: GRPOConfig
rank = int(os.getenv("RANK"))
world_size = int(os.getenv("WORLD_SIZE"))
tokenizer = load_hf_tokenizer(config.tokenizer_path)
seeding.set_random_seed(config.seed, key=f"trainer{rank}")
# Create dataset and dataloaders
train_dataloader = StatefulDataLoader(
get_gsm8k_dataset("train", rank, world_size),
@@ -133,6 +103,9 @@ def main_grpo():
gconfig=config.gconfig,
tokenizer=tokenizer,
enable_thinking=False,
dump_dir=os.path.join(
StatsLogger.get_log_path(config.stats_logger), "generated"
),
)
# Run training.
@@ -190,13 +163,14 @@ def main_grpo():
log_gpu_stats("ppo update")
with stats_tracker.record_timing("update_weights"):
path = os.path.join(
Saver.get_save_checkpoint_root(config.saver),
"update_weights",
str(global_step + 1),
)
meta = WeightUpdateMeta(
type="disk",
path=os.path.join(
Saver.get_save_checkpoint_root(config.saver),
"update_weights",
str(global_step),
),
path=path,
alloc_mode=None,
comm_backend=None,
model_version=global_step + 1,
@@ -206,6 +180,7 @@ def main_grpo():
actor.upload_weights(meta)
if dist.get_rank() == 0:
future.result()
shutil.rmtree(path, ignore_errors=True)
dist.barrier()
torch.cuda.synchronize()
rollout.set_version(global_step + 1)
@@ -253,4 +228,4 @@ def main_grpo():
if __name__ == "__main__":
main_grpo()
main(sys.argv[1:])

View File

@@ -35,8 +35,8 @@ def get_gsm8k_dataset(split, tokenizer, rank, world_size):
return process_gsm8k_sft_dataset(dataset, tokenizer)
def main_sft():
config, _ = load_expr_config(sys.argv[1:], SFTConfig)
def main(args):
config, _ = load_expr_config(args, SFTConfig)
config: SFTConfig
rank = int(os.getenv("RANK"))
@@ -121,4 +121,4 @@ def main_sft():
if __name__ == "__main__":
main_sft()
main(sys.argv[1:])