mirror of https://github.com/inclusionAI/AReaL
[Feature] [lite] Merge from internal dev repo (#189)
* PullRequest: 353 [Lite] Add gradient checkpointing to FSDPEngine Merge branch mzy/add-gradient-ckpt of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/353 Reviewed-by: 博惟 <bowei.fw@antgroup.com> * add gradient checkpointing * PullRequest: 354 [lite] GRPO pre-commit: minor changes in FSDP engine Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/354 Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com> * . * . * . * . * PullRequest: 355 [Lite] GRPO pre-commit 2: Refactor RemoteSGLangEngine thread and SGLang configuration Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/355?tab=commit Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com> * . * . * . * . * . * . * fix * . * PullRequest: 357 [lite] GRPO pre-commit 3: Fix typos and experiment utilities Merge branch fw/lite-fix2 of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/357?tab=comment Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com> * . * . * . * . * . * fix destroy process group * PullRequest: 358 [lite] Support GRPO training locally with the GSM8k dataset Merge branch fw/lite-fix3 of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/358 Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com> * . * . * . * . * fix loss mask * fix * . * PullRequest: 368 [lite] Refactor train engine after merging contributions from GitHub Merge branch fw/lite-train-engine of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/368 Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com> * . * . * PullRequest: 371 [lite] [fix] fix misc bugs in GRPO implementation Merge branch fw/lite-fix0716 of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/371 Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com> * . * PullRequest: 370 [lite] Add Slurm Launcher and Ray Launcher Merge branch mzy/lite/launcher of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/370 Reviewed-by: 博惟 <bowei.fw@antgroup.com> * . * . * . * fix * PullRequest: 392 [lite] Fix several bugs regarding RL learning and add an example to reproduce boba-math results. Merge branch fw/lite-boba of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/392 Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com> * support fsdp engine and sglang remote engine * minor fix * . * refactor trainer * add close * rm mb_spec * . * fix * . * qwen2 grpo works * fix * fix * async works * fix * slurm launcher not tested * fix arg parse * . * sglang server wrapper * . * . * slurm run * ready for boba * debug * 32k run * . * . * fix * . * . * . * . * . * fix * . * fix * . * . * . * . * fix * . * . * . * . * . * . * . * refactor train engine * refactor train engine * . * fix update weight error * . * . * match train * format * . * fix * seems to work * . * . * . * . * format * format * . --------- Co-authored-by: 晓雷 <meizhiyu.mzy@antgroup.com>
This commit is contained in:
parent
f68a4f677d
commit
18f8a056b6
|
@ -676,20 +676,6 @@ class ClusterSpecConfig:
|
|||
"help": "Root for logs and checkpoints. Should be available to all nodes."
|
||||
},
|
||||
)
|
||||
gpu_type: str = field(
|
||||
default="tesla", metadata={"help": "GPU type of the cluster. Used by slurm."}
|
||||
)
|
||||
mount: str = field(
|
||||
default="/storage:/storage", metadata={"help": "Mount path for slurm."}
|
||||
)
|
||||
gpu_image: str = field(default="", metadata={"help": "slurm image for trainers."})
|
||||
cpu_image: str = field(default="", metadata={"help": "slurm image for CPU jobs."})
|
||||
gpu_infer_image: str = field(
|
||||
default="", metadata={"help": "slurm image for LLM inference."}
|
||||
)
|
||||
node_name_prefix: str = field(
|
||||
default="slurmd-", metadata={"help": "Node prefix for a slurm cluster."}
|
||||
)
|
||||
n_nodes: int = field(
|
||||
default=32,
|
||||
metadata={
|
||||
|
@ -725,6 +711,72 @@ class DatasetConfig:
|
|||
drop_last: bool = field(default=True)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SlurmLauncherConfig:
|
||||
"""Configuration for launching the SGLang server with Slurm."""
|
||||
|
||||
srun_additional_args: str = field(
|
||||
default="--overlap --mpi=pmi2 -K --chdir $PWD",
|
||||
metadata={"help": "Additional arguments to pass to the srun command."},
|
||||
)
|
||||
container_type: str = field(
|
||||
default="apptainer",
|
||||
metadata={
|
||||
"help": "Type of containers used in slurm",
|
||||
"choices": ["apptainer", "none"],
|
||||
},
|
||||
)
|
||||
mount: str = field(
|
||||
default="/storage:/storage", metadata={"help": "Mount path for slurm."}
|
||||
)
|
||||
trainer_image: str = field(
|
||||
default="", metadata={"help": "slurm image for trainers."}
|
||||
)
|
||||
inference_server_image: str = field(
|
||||
default="", metadata={"help": "slurm image for LLM inference."}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LauncherConfig:
|
||||
"""Configuration for launching the SGLang server."""
|
||||
|
||||
inference_server_cpus_per_gpu: int = field(
|
||||
default=4,
|
||||
metadata={"help": "Number of CPUs allocated per GPU for inference server. "},
|
||||
)
|
||||
inference_server_mem_per_gpu: int = field(
|
||||
default=32 * 1024,
|
||||
metadata={"help": "Memory allocated per GPU for inference server in MB. "},
|
||||
)
|
||||
trainer_cpus_per_gpu: int = field(
|
||||
default=4,
|
||||
metadata={"help": "Number of CPUs allocated per GPU for training. "},
|
||||
)
|
||||
trainer_mem_per_gpu: int = field(
|
||||
default=32 * 1024,
|
||||
metadata={"help": "Memory allocated per GPU for training in MB. "},
|
||||
)
|
||||
inference_server_env_vars: str = field(
|
||||
default="",
|
||||
metadata={
|
||||
"help": "Environment variables for inference server, seperated by commas. "
|
||||
"Example: 'ENV1=val1,ENV2=val2'. "
|
||||
},
|
||||
)
|
||||
trainer_env_vars: str = field(
|
||||
default="",
|
||||
metadata={
|
||||
"help": "Environment variables for training, seperated by commas. "
|
||||
"Example: 'ENV1=val1,ENV2=val2'. "
|
||||
},
|
||||
)
|
||||
slurm: SlurmLauncherConfig = field(
|
||||
default_factory=SlurmLauncherConfig,
|
||||
metadata={"help": "Slurm launcher configuration."},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseExperimentConfig:
|
||||
# NOTE: we need this unified config class because different experiments
|
||||
|
@ -742,12 +794,6 @@ class BaseExperimentConfig:
|
|||
default_factory=ClusterSpecConfig,
|
||||
metadata={"help": "Cluster specification. Mainly used by slurm."},
|
||||
)
|
||||
n_nodes: int = field(
|
||||
default=1, metadata={"help": "Number of nodes for experiment."}
|
||||
)
|
||||
n_gpus_per_node: int = field(
|
||||
default=8, metadata={"help": "Number of GPUs per node for this experiment."}
|
||||
)
|
||||
allocation_mode: str = field(
|
||||
default="",
|
||||
metadata={
|
||||
|
@ -785,6 +831,7 @@ class BaseExperimentConfig:
|
|||
|
||||
server_only: bool = False
|
||||
sglang: SGLangConfig = field(default_factory=SGLangConfig)
|
||||
launcher: LauncherConfig = field(default_factory=LauncherConfig)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
@ -266,6 +266,7 @@ class BaseHFEngine(TrainEngine):
|
|||
|
||||
# Scale loss for accumulation
|
||||
# Revert gradient averaging across dp ranks
|
||||
# FIXME: should be DP size
|
||||
loss_scale *= self.world_size
|
||||
|
||||
loss *= loss_scale
|
||||
|
@ -286,8 +287,6 @@ class BaseHFEngine(TrainEngine):
|
|||
update_successful = True
|
||||
|
||||
current_lr = self.lr_scheduler.get_last_lr()[0]
|
||||
# Optimizer step
|
||||
self.optimizer.step()
|
||||
return dict(
|
||||
update_successful=float(update_successful),
|
||||
grad_norm=float(grad_norm) if grad_norm is not None else float("nan"),
|
||||
|
|
|
@ -1,10 +1,7 @@
|
|||
import dis
|
||||
import gc
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
from typing import Callable, Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -14,14 +11,7 @@ from torch.distributed.checkpoint.state_dict import (
|
|||
StateDictOptions,
|
||||
get_model_state_dict,
|
||||
)
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
PreTrainedTokenizerFast,
|
||||
get_constant_schedule_with_warmup,
|
||||
get_linear_schedule_with_warmup,
|
||||
)
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
|
||||
from arealite.api.cli_args import TrainEngineConfig
|
||||
from arealite.api.engine_api import FinetuneSpec, SaveLoadMeta, WeightUpdateMeta
|
||||
|
@ -232,6 +222,7 @@ class FSDPEngine(BaseHFEngine):
|
|||
for i, (pad_length, padded_mb_input, mb_input) in enumerate(
|
||||
zip(mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs)
|
||||
):
|
||||
self.model.set_requires_gradient_sync(i == len(mb_list.mbs) - 1)
|
||||
outputs = self.model(**padded_mb_input)
|
||||
|
||||
logits = outputs.logits.squeeze(0)
|
||||
|
@ -258,8 +249,6 @@ class FSDPEngine(BaseHFEngine):
|
|||
update_successful = True
|
||||
|
||||
current_lr = self.lr_scheduler.get_last_lr()[0]
|
||||
# Optimizer step
|
||||
self.optimizer.step()
|
||||
return dict(
|
||||
update_successful=float(update_successful),
|
||||
grad_norm=float(grad_norm) if grad_norm is not None else float("nan"),
|
||||
|
|
|
@ -58,15 +58,9 @@ def compute_packed_sft_loss(logits: torch.Tensor, input_: TensorDict) -> torch.T
|
|||
cu_seqlens.shape[0] - 1, device=logits.device, dtype=torch.float64
|
||||
)
|
||||
for i in range(cu_seqlens.shape[0] - 1):
|
||||
m = loss_mask[cu_seqlens[i] - i : cu_seqlens[i + 1] - i - 1]
|
||||
logp = logprobs[cu_seqlens[i] - i : cu_seqlens[i + 1] - i - 1]
|
||||
assert cu_seqlens[i + 1] - i - 1 <= logprobs.shape[0], (
|
||||
cu_seqlens,
|
||||
logprobs.shape,
|
||||
)
|
||||
seqlogp[i] = torch.where(m, logp.detach(), 0.0).sum() / (
|
||||
m.numel() - m.count_nonzero()
|
||||
)
|
||||
m = loss_mask[cu_seqlens[i] : cu_seqlens[i + 1]]
|
||||
logp = logprobs[cu_seqlens[i] : cu_seqlens[i + 1]]
|
||||
seqlogp[i] = torch.where(m, logp.detach(), 0.0).sum() / (m.count_nonzero())
|
||||
|
||||
## Loggin stats
|
||||
stats_tracker.denominator(
|
||||
|
|
|
@ -0,0 +1,410 @@
|
|||
import getpass
|
||||
import importlib.util
|
||||
import os
|
||||
import pathlib
|
||||
import sys
|
||||
import time
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import ray
|
||||
import ray.exceptions
|
||||
from ray.runtime_env import RuntimeEnv
|
||||
from ray.util.placement_group import PlacementGroup
|
||||
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||
|
||||
import realhf.base.logging as logging
|
||||
from arealite.api.cli_args import (
|
||||
ClusterSpecConfig,
|
||||
LauncherConfig,
|
||||
SGLangConfig,
|
||||
parse_cli_args,
|
||||
to_structured_cfg,
|
||||
)
|
||||
from arealite.api.io_struct import AllocationMode, AllocationType
|
||||
from arealite.utils.launcher import (
|
||||
get_env_vars,
|
||||
validate_config_for_distributed_launcher,
|
||||
wait_sglang_server_addrs,
|
||||
)
|
||||
from arealite.utils.ray import get_placement_group_master_ip_and_port
|
||||
from realhf.base import logging, name_resolve, names
|
||||
from realhf.scheduler.client import JobException, JobState
|
||||
|
||||
logger = logging.getLogger("RayLauncher")
|
||||
|
||||
RAY_WAIT_CHECK_TIME_INTERVAL = 5 # seconds
|
||||
DEFAULT_MAIN_FUNC_NAME = "main"
|
||||
|
||||
|
||||
def run_func(file_path, function_name, *args, **kwargs):
|
||||
# Convert the file path to a module name
|
||||
module_name = file_path.replace("/", "_").replace(".", "_")
|
||||
|
||||
# Load the module from file path
|
||||
spec = importlib.util.spec_from_file_location(module_name, file_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[module_name] = module
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
# Get the function and execute it
|
||||
try:
|
||||
function = getattr(module, function_name)
|
||||
except AttributeError as e:
|
||||
raise ValueError(
|
||||
f"Function '{function_name}' not found in module '{module_name}'. "
|
||||
f"Please ensure the name of the main function in your entry point "
|
||||
f"is '{function_name}'."
|
||||
) from e
|
||||
return function(*args, **kwargs)
|
||||
|
||||
|
||||
class RayLauncher:
|
||||
def __init__(self, experiment_name: str, trial_name: str, fileroot: str):
|
||||
self.experiment_name = experiment_name
|
||||
self.trial_name = trial_name
|
||||
self.fileroot = fileroot
|
||||
|
||||
# job_name to ray future
|
||||
self.jobs = {}
|
||||
|
||||
@property
|
||||
def run_name(self):
|
||||
return f"{self.experiment_name}_{self.trial_name}"
|
||||
|
||||
def log_path_of(self, job_name: str) -> str:
|
||||
log_path = f"{self.fileroot}/logs/{getpass.getuser()}/{self.experiment_name}/{self.trial_name}"
|
||||
os.makedirs(log_path, exist_ok=True)
|
||||
return os.path.join(log_path, f"{job_name}.log")
|
||||
|
||||
def submit(
|
||||
self,
|
||||
job_name: str,
|
||||
file_path: str,
|
||||
func_name: str,
|
||||
args: List[str], # arguments to pass to the function
|
||||
gpus: int,
|
||||
cpus: int,
|
||||
mem: int, # MB
|
||||
env_vars: Optional[Dict] = None,
|
||||
placement_group: Optional[PlacementGroup] = None,
|
||||
bundle_index: int = -1,
|
||||
kwargs: Optional[
|
||||
Dict[str, str]
|
||||
] = None, # keyword arguments to pass to the function
|
||||
):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
runtime_env = RuntimeEnv(
|
||||
env_vars=env_vars or dict(),
|
||||
)
|
||||
scheduling_strategy = (
|
||||
PlacementGroupSchedulingStrategy(
|
||||
placement_group=placement_group,
|
||||
placement_group_bundle_index=bundle_index,
|
||||
placement_group_capture_child_tasks=True,
|
||||
)
|
||||
if placement_group is not None
|
||||
else "DEFAULT"
|
||||
)
|
||||
future = ray.remote(
|
||||
num_cpus=cpus,
|
||||
num_gpus=gpus,
|
||||
memory=mem * 1024 * 1024, # Convert MB to bytes
|
||||
runtime_env=runtime_env,
|
||||
scheduling_strategy=scheduling_strategy,
|
||||
)(run_func).remote(file_path, func_name, *args, **kwargs)
|
||||
self.jobs[job_name] = future
|
||||
return future
|
||||
|
||||
def submit_array(
|
||||
self,
|
||||
job_name: str,
|
||||
file_path: str,
|
||||
func_name: str,
|
||||
count: int,
|
||||
nodes: int,
|
||||
list_args: List[List],
|
||||
gpus_per_task: int,
|
||||
cpus_per_task: int,
|
||||
mem_per_task: int, # MB
|
||||
list_kwargs: List[Dict] | None = None,
|
||||
env_vars: Optional[Dict] = None,
|
||||
amend_torch_dist_env: bool = False,
|
||||
):
|
||||
"""Submit an array of jobs to Ray with ray placement groups.
|
||||
|
||||
Note: Here we use `ray.remote` instead of `ray job submit` since `ray job submit`
|
||||
does not support placement groups, and can not specify which node to run the job on.
|
||||
Therefore we could not know the IP address of jobs for torch distributed initialization.
|
||||
"""
|
||||
|
||||
if count % nodes != 0:
|
||||
raise ValueError(
|
||||
f"Count {count} is not divisible by nodes {nodes}. "
|
||||
"Please ensure that count is a multiple of nodes."
|
||||
)
|
||||
assert (
|
||||
len(list_args) == count
|
||||
), f"Length of list_args {len(list_args)} does not match count {count}."
|
||||
if list_kwargs is not None:
|
||||
assert (
|
||||
len(list_kwargs) == count
|
||||
), f"Length of list_kwargs {len(list_kwargs)} does not match count {count}."
|
||||
|
||||
tasks_per_node = count // nodes
|
||||
gpus_per_node = gpus_per_task * tasks_per_node
|
||||
cpus_per_node = cpus_per_task * tasks_per_node
|
||||
mem_per_node = mem_per_task * tasks_per_node
|
||||
|
||||
placement_group = ray.util.placement_group(
|
||||
bundles=[
|
||||
{
|
||||
"CPU": cpus_per_node,
|
||||
"GPU": gpus_per_node,
|
||||
"memory": mem_per_node * 1024 * 1024, # Convert MB to bytes
|
||||
}
|
||||
]
|
||||
* nodes,
|
||||
strategy="STRICT_SPREAD",
|
||||
)
|
||||
try:
|
||||
ray.get(placement_group.ready(), timeout=30)
|
||||
except ray.exceptions.GetTimeoutError as e:
|
||||
logger.error(
|
||||
"Ray placement group timeout, please check if the resource requirement "
|
||||
"for your experiment exceeds the available resources in the cluster. \n"
|
||||
f"ray.nodes(): {ray.nodes()} \n"
|
||||
f"Placement Group bundles: "
|
||||
f"cpus_per_node={cpus_per_node}, gpus_per_node={gpus_per_node}, "
|
||||
f"mem_per_node={mem_per_node}MB, nodes={nodes}"
|
||||
)
|
||||
raise e
|
||||
|
||||
if amend_torch_dist_env:
|
||||
host_ip, port = get_placement_group_master_ip_and_port(placement_group)
|
||||
logger.info(
|
||||
f"Amend torch distributed env vars: "
|
||||
f"MASTER_ADDR={host_ip}, PORT={port}"
|
||||
)
|
||||
|
||||
futures = []
|
||||
for i in range(count):
|
||||
args = list_args[i]
|
||||
kwargs = list_kwargs[i] if list_kwargs is not None else {}
|
||||
|
||||
# manage environment variables
|
||||
env_vars = env_vars or {}
|
||||
if "CUDA_VISIBLE_DEVICES" in env_vars:
|
||||
logger.warning(
|
||||
"Setting CUDA_VISIBLE_DEVICES before running ray jobs may result in unexpected behavior."
|
||||
)
|
||||
|
||||
node_id = i // tasks_per_node
|
||||
_env_vars = {
|
||||
**env_vars,
|
||||
}
|
||||
|
||||
if amend_torch_dist_env:
|
||||
assert gpus_per_task == 1
|
||||
# NOTE: Here we only provide environment variables for torch distributed
|
||||
# initialization, and LOCAL_RANK for torch.device.
|
||||
# Other environment variables automatically set by torchrun are not set, and
|
||||
# they should be never accessed in trainer code.
|
||||
_env_vars.update(
|
||||
{
|
||||
"RANK": str(i),
|
||||
"WORLD_SIZE": str(count),
|
||||
# Ray will automatically isolate CUDA_VISIBLE_DEVICES for each GPU
|
||||
"LOCAL_RANK": "0",
|
||||
"MASTER_ADDR": str(host_ip),
|
||||
"MASTER_PORT": str(port),
|
||||
}
|
||||
)
|
||||
future = self.submit(
|
||||
job_name=f"{job_name}:{i}",
|
||||
file_path=file_path,
|
||||
func_name=func_name,
|
||||
args=args,
|
||||
gpus=gpus_per_task,
|
||||
cpus=cpus_per_task,
|
||||
mem=mem_per_task,
|
||||
env_vars=_env_vars,
|
||||
placement_group=placement_group,
|
||||
bundle_index=node_id,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
futures.append(future)
|
||||
|
||||
return futures
|
||||
|
||||
def stop(self, job_name: str, force: bool = False):
|
||||
"""Stop a job by name."""
|
||||
if job_name in self.jobs:
|
||||
future = self.jobs[job_name]
|
||||
try:
|
||||
ray.cancel(future, force=force)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to cancel job {job_name}: {e}")
|
||||
return
|
||||
self.jobs.pop(job_name, None)
|
||||
logger.info(f"Job {job_name} stopped.")
|
||||
else:
|
||||
logger.warning(f"Job {job_name} not found in running jobs.")
|
||||
|
||||
def stop_all(self, force: bool = False):
|
||||
"""Stop all jobs."""
|
||||
for job_name in list(self.jobs.keys()):
|
||||
self.stop(job_name, force=force)
|
||||
logger.info("All jobs stopped.")
|
||||
self.jobs.clear()
|
||||
|
||||
def wait(
|
||||
self, check_status=(JobState.FAILED,), remove_status=(JobState.COMPLETED,)
|
||||
):
|
||||
"""Check every RAY_WAIT_CHECK_TIME_INTERVAL seconds for the status of all jobs.
|
||||
If a ray job returns, its status changes to JobState.COMPLETED.
|
||||
If a ray job failed, its status changes to JobState.FAILED.
|
||||
If any job is in check_status, stop all jobs at once.
|
||||
If any job is in remove status, remove them from job list.
|
||||
Return if all jobs are removed from job list, or some job is in check status.
|
||||
"""
|
||||
for status in list(check_status) + list(remove_status):
|
||||
assert status in [
|
||||
JobState.COMPLETED,
|
||||
JobState.FAILED,
|
||||
], "In RayLauncher.wait, we only check completed or failed jobs."
|
||||
logger.info(f"Waiting for {len(self.jobs)} jobs.")
|
||||
while self.jobs:
|
||||
job_status = {}
|
||||
for job_name, future in list(self.jobs.items()):
|
||||
try:
|
||||
r = ray.get(future, timeout=0.1)
|
||||
logger.info(f"Job {job_name} completed with result: {r}")
|
||||
job_status[job_name] = JobState.COMPLETED
|
||||
except ray.exceptions.RayTaskError as e:
|
||||
logger.error(f"Job {job_name} failed with error: {e}.")
|
||||
job_status[job_name] = JobState.FAILED
|
||||
except ray.exceptions.GetTimeoutError:
|
||||
continue
|
||||
|
||||
for job_name, status in job_status.items():
|
||||
if status in check_status:
|
||||
logger.info(f"Job {job_name} is {status}, stopping all jobs.")
|
||||
self.stop_all(force=True)
|
||||
return
|
||||
if status in remove_status:
|
||||
logger.info(f"Job {job_name} is {status}, removed.")
|
||||
self.jobs.pop(job_name)
|
||||
|
||||
time.sleep(RAY_WAIT_CHECK_TIME_INTERVAL)
|
||||
|
||||
|
||||
def ray_main():
|
||||
# usage: python -m arealite.launcher.ray <entry_point> --config <config_path> [<additional_args>]
|
||||
ray.init()
|
||||
config, config_file = parse_cli_args(sys.argv[2:])
|
||||
config.launcher = to_structured_cfg(config.launcher, LauncherConfig)
|
||||
config.cluster = to_structured_cfg(config.cluster, ClusterSpecConfig)
|
||||
config.sglang = to_structured_cfg(config.sglang, SGLangConfig)
|
||||
validate_config_for_distributed_launcher(config)
|
||||
|
||||
name_resolve.reconfigure(config.cluster.name_resolve)
|
||||
name_resolve.clear_subtree(
|
||||
names.trial_root(
|
||||
experiment_name=config.experiment_name, trial_name=config.trial_name
|
||||
)
|
||||
)
|
||||
|
||||
n_nodes = config.cluster.n_nodes
|
||||
n_gpus_per_node = config.cluster.n_gpus_per_node
|
||||
launcher = RayLauncher(
|
||||
experiment_name=config.experiment_name,
|
||||
trial_name=config.trial_name,
|
||||
fileroot=config.cluster.fileroot,
|
||||
)
|
||||
allocation_mode = config.allocation_mode
|
||||
allocation_mode = AllocationMode.from_str(allocation_mode)
|
||||
sglang_cmds = []
|
||||
sglang_addrs = []
|
||||
n_sglang_nodes = 0
|
||||
if allocation_mode.type_ == AllocationType.DECOUPLED_SGLANG:
|
||||
# Launcher should launch SGLang servers according to allocation mode.
|
||||
sglang_tp_size = allocation_mode.gen_tp_size
|
||||
n_sglang_servers = allocation_mode.gen_dp_size
|
||||
n_sglang_nodes = allocation_mode.gen_world_size // n_gpus_per_node
|
||||
|
||||
base_seed = config.sglang.random_seed
|
||||
sglang_args_list = [
|
||||
[sys.argv[2:] + [f"sglang.random_seed={base_seed + i}"]]
|
||||
for i in range(n_sglang_servers)
|
||||
]
|
||||
sglang_entry_point = str(
|
||||
pathlib.Path(__file__).resolve().parent.joinpath("sglang_server.py")
|
||||
)
|
||||
launcher.submit_array(
|
||||
job_name="llm_server",
|
||||
file_path=sglang_entry_point,
|
||||
func_name=DEFAULT_MAIN_FUNC_NAME,
|
||||
count=n_sglang_servers,
|
||||
nodes=n_sglang_nodes,
|
||||
list_args=sglang_args_list,
|
||||
gpus_per_task=sglang_tp_size,
|
||||
cpus_per_task=config.launcher.inference_server_cpus_per_gpu
|
||||
* sglang_tp_size,
|
||||
mem_per_task=config.launcher.inference_server_mem_per_gpu * sglang_tp_size,
|
||||
env_vars=get_env_vars(
|
||||
config.cluster.cluster_name,
|
||||
config.launcher.inference_server_env_vars,
|
||||
),
|
||||
)
|
||||
# Get SGLang server addresses via name_resolve
|
||||
try:
|
||||
sglang_addrs = wait_sglang_server_addrs(
|
||||
config.experiment_name,
|
||||
config.trial_name,
|
||||
n_sglang_servers,
|
||||
)
|
||||
except TimeoutError as e:
|
||||
launcher.stop_all(force=True)
|
||||
raise e
|
||||
|
||||
trainer_n_nodes = n_nodes - n_sglang_nodes
|
||||
trainer_entry_point = sys.argv[1]
|
||||
n_trainer_processes = trainer_n_nodes * config.cluster.n_gpus_per_node
|
||||
trainer_args_list = [[sys.argv[2:]] for _ in range(n_trainer_processes)]
|
||||
if not config.server_only:
|
||||
# In ray, we launch trainer in the granularity of processes (1 GPU per process)
|
||||
# We amend environment variable similar to torchrun to ensure correct initialization of
|
||||
# torch distributed.
|
||||
launcher.submit_array(
|
||||
job_name="trainer",
|
||||
file_path=trainer_entry_point,
|
||||
func_name=DEFAULT_MAIN_FUNC_NAME,
|
||||
count=trainer_n_nodes * config.cluster.n_gpus_per_node,
|
||||
nodes=trainer_n_nodes,
|
||||
list_args=trainer_args_list,
|
||||
gpus_per_task=1,
|
||||
cpus_per_task=config.launcher.trainer_cpus_per_gpu,
|
||||
mem_per_task=config.launcher.trainer_mem_per_gpu,
|
||||
env_vars=dict(
|
||||
**get_env_vars(
|
||||
config.cluster.cluster_name,
|
||||
config.launcher.trainer_env_vars,
|
||||
),
|
||||
AREAL_LLM_SERVER_ADDRS=",".join(sglang_addrs),
|
||||
),
|
||||
amend_torch_dist_env=True,
|
||||
)
|
||||
|
||||
try:
|
||||
launcher.wait(check_status=(JobState.COMPLETED, JobState.FAILED))
|
||||
except (KeyboardInterrupt, JobException, TimeoutError) as e:
|
||||
launcher.stop_all(force=True)
|
||||
raise e
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# usage: python -m arealite.launcher.ray \
|
||||
# <entry_point> --config <config_path> [<additional_args>] \
|
||||
# launcher.ray.main_func_name=<main_func_name_in_entry_point>
|
||||
ray_main()
|
|
@ -0,0 +1,197 @@
|
|||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import ray
|
||||
import requests
|
||||
|
||||
from arealite.api.cli_args import (
|
||||
ClusterSpecConfig,
|
||||
NameResolveConfig,
|
||||
SGLangConfig,
|
||||
parse_cli_args,
|
||||
to_structured_cfg,
|
||||
)
|
||||
from arealite.api.io_struct import AllocationMode, AllocationType
|
||||
from arealite.utils.launcher import TRITON_CACHE_PATH
|
||||
from arealite.utils.network import find_free_ports, gethostip
|
||||
from realhf.base import logging, name_resolve, names, pkg_version
|
||||
|
||||
logger = logging.getLogger("SGLangServer Wrapper")
|
||||
|
||||
|
||||
def execute_shell_command(command: str) -> subprocess.Popen:
|
||||
"""
|
||||
Execute a shell command and return its process handle.
|
||||
"""
|
||||
# Replace newline continuations and split the command string.
|
||||
command = command.replace("\\\n", " ").replace("\\", " ")
|
||||
parts = command.split()
|
||||
_env = os.environ.copy()
|
||||
# To avoid DirectoryNotEmpty error caused by triton
|
||||
triton_cache_path = _env.get("TRITON_CACHE_PATH", TRITON_CACHE_PATH)
|
||||
unique_triton_cache_path = os.path.join(triton_cache_path, str(uuid.uuid4()))
|
||||
_env["TRITON_CACHE_PATH"] = unique_triton_cache_path
|
||||
return subprocess.Popen(
|
||||
parts,
|
||||
text=True,
|
||||
env=_env,
|
||||
stdout=sys.stdout,
|
||||
stderr=subprocess.STDOUT,
|
||||
)
|
||||
|
||||
|
||||
def apply_sglang_patch():
|
||||
p = Path(os.path.dirname(__file__))
|
||||
patch_path = str(
|
||||
p.parent.parent
|
||||
/ "patch"
|
||||
/ "sglang"
|
||||
/ f"v{pkg_version.get_version('sglang')}.patch"
|
||||
)
|
||||
|
||||
target_path = ""
|
||||
sglang_meta = subprocess.check_output(
|
||||
"python3 -m pip show sglang", shell=True
|
||||
).decode("ascii")
|
||||
for line in sglang_meta.split("\n"):
|
||||
line = line.strip()
|
||||
if line.startswith("Editable project location: "):
|
||||
target_path = str(Path(line.split(": ")[1]).parent)
|
||||
|
||||
if target_path:
|
||||
proc = subprocess.Popen(
|
||||
["git", "apply", patch_path],
|
||||
cwd=target_path,
|
||||
stderr=sys.stdout,
|
||||
stdout=sys.stdout,
|
||||
)
|
||||
proc.wait()
|
||||
logger.info(f"Applied SGLang patch at {target_path}")
|
||||
|
||||
|
||||
def launch_server_cmd(command: str):
|
||||
"""
|
||||
Launch the server using the given command.
|
||||
If no port is specified, a free port is reserved.
|
||||
"""
|
||||
if not ray.is_initialized():
|
||||
apply_sglang_patch()
|
||||
process = execute_shell_command(command)
|
||||
return process
|
||||
|
||||
|
||||
def wait_for_server(base_url: str, timeout: Optional[int] = None) -> None:
|
||||
"""Wait for the server to be ready by polling the /v1/models endpoint.
|
||||
|
||||
Args:
|
||||
base_url: The base URL of the server
|
||||
timeout: Maximum time to wait in seconds. None means wait forever.
|
||||
"""
|
||||
start_time = time.time()
|
||||
while True:
|
||||
try:
|
||||
response = requests.get(
|
||||
f"{base_url}/v1/models",
|
||||
headers={"Authorization": "Bearer None"},
|
||||
)
|
||||
if response.status_code == 200:
|
||||
time.sleep(5)
|
||||
break
|
||||
|
||||
if timeout and time.time() - start_time > timeout:
|
||||
raise TimeoutError("Server did not become ready within timeout period")
|
||||
except requests.exceptions.RequestException:
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
class SGLangServerWrapper:
|
||||
def __init__(
|
||||
self,
|
||||
experiment_name: str,
|
||||
trial_name: str,
|
||||
sglang_config: SGLangConfig,
|
||||
tp_size: int,
|
||||
n_gpus_per_node: int,
|
||||
):
|
||||
self.experiment_name = experiment_name
|
||||
self.trial_name = trial_name
|
||||
self.config = sglang_config
|
||||
self.tp_size = tp_size
|
||||
self.server_process = None
|
||||
self.n_gpus_per_node = n_gpus_per_node
|
||||
|
||||
def run(self):
|
||||
gpus_per_server = len(os.getenv("CUDA_VISIBLE_DEVICES").split(","))
|
||||
server_local_idx = (
|
||||
int(os.getenv("CUDA_VISIBLE_DEVICES").split(",")[0]) // gpus_per_server
|
||||
)
|
||||
n_servers_per_node = max(1, self.n_gpus_per_node // gpus_per_server)
|
||||
ports_per_server = 40000 // n_servers_per_node
|
||||
port_range = (
|
||||
server_local_idx * ports_per_server + 10000,
|
||||
(server_local_idx + 1) * ports_per_server + 10000,
|
||||
)
|
||||
server_port, dist_init_port = find_free_ports(2, port_range)
|
||||
|
||||
dist_init_addr = f"localhost:{dist_init_port}"
|
||||
host_ip = gethostip()
|
||||
|
||||
cmd = SGLangConfig.build_cmd(
|
||||
self.config,
|
||||
tp_size=self.tp_size,
|
||||
base_gpu_id=0,
|
||||
host=host_ip,
|
||||
port=server_port,
|
||||
dist_init_addr=dist_init_addr,
|
||||
)
|
||||
self.server_process = launch_server_cmd(cmd)
|
||||
wait_for_server(f"http://{host_ip}:{server_port}")
|
||||
|
||||
name = names.gen_servers(self.experiment_name, self.trial_name)
|
||||
name_resolve.add_subentry(name, f"{host_ip}:{server_port}")
|
||||
|
||||
logger.info(f"SGLang server launched at: http://{host_ip}:{server_port}")
|
||||
return_code = self.server_process.wait()
|
||||
logger.info(
|
||||
f"SGLang server at http://{host_ip}:{server_port} exits, returncode={return_code}"
|
||||
)
|
||||
|
||||
def __del__(self):
|
||||
if self.server_process and self.server_process.poll() is None:
|
||||
logger.info("Terminating SGLang server process...")
|
||||
self.server_process.terminate()
|
||||
self.server_process.wait()
|
||||
logger.info("SGLang server process terminated.")
|
||||
|
||||
|
||||
def main(argv):
|
||||
config, _ = parse_cli_args(argv)
|
||||
config.sglang = to_structured_cfg(config.sglang, SGLangConfig)
|
||||
config.cluster = to_structured_cfg(config.cluster, ClusterSpecConfig)
|
||||
config.cluster.name_resolve = to_structured_cfg(
|
||||
config.cluster.name_resolve, NameResolveConfig
|
||||
)
|
||||
name_resolve.reconfigure(config.cluster.name_resolve)
|
||||
|
||||
allocation_mode = config.allocation_mode
|
||||
allocation_mode = AllocationMode.from_str(allocation_mode)
|
||||
assert allocation_mode.type_ == AllocationType.DECOUPLED_SGLANG
|
||||
tp_size = allocation_mode.gen_tp_size
|
||||
|
||||
sglang_server = SGLangServerWrapper(
|
||||
config.experiment_name,
|
||||
config.trial_name,
|
||||
config.sglang,
|
||||
tp_size,
|
||||
n_gpus_per_node=config.cluster.n_gpus_per_node,
|
||||
)
|
||||
sglang_server.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(sys.argv[1:])
|
|
@ -0,0 +1,528 @@
|
|||
import getpass
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import realhf.base.logging as logging
|
||||
from arealite.api.cli_args import (
|
||||
ClusterSpecConfig,
|
||||
LauncherConfig,
|
||||
SGLangConfig,
|
||||
parse_cli_args,
|
||||
to_structured_cfg,
|
||||
)
|
||||
from arealite.api.io_struct import AllocationMode, AllocationType
|
||||
from arealite.utils.launcher import (
|
||||
get_env_vars,
|
||||
validate_config_for_distributed_launcher,
|
||||
wait_sglang_server_addrs,
|
||||
)
|
||||
from arealite.utils.slurm import (
|
||||
APPTAINER_CMD_TEMPLATE,
|
||||
SBATCH_SCRIPT_TEMPLATE,
|
||||
SRUN_CMD_TEMPLATE,
|
||||
cancel_jobs,
|
||||
query_jobs,
|
||||
)
|
||||
from realhf.base import logging, name_resolve, names
|
||||
from realhf.scheduler.client import JobException, JobInfo, JobState
|
||||
|
||||
logger = logging.getLogger("SlurmLauncher")
|
||||
|
||||
SLURM_WAIT_CHECK_TIME_INTERVAL = 5
|
||||
|
||||
|
||||
class SlurmLauncher:
|
||||
def __init__(
|
||||
self, experiment_name: str, trial_name: str, fileroot: str, container_type: str
|
||||
):
|
||||
self.experiment_name = experiment_name
|
||||
self.trial_name = trial_name
|
||||
self.fileroot = fileroot
|
||||
self.container_type = container_type
|
||||
|
||||
# slurm_job_id -> JobInfo
|
||||
self.jobs: Dict[int, JobInfo] = {}
|
||||
self.job_names = []
|
||||
|
||||
@property
|
||||
def run_name(self) -> str:
|
||||
"""Returns the run name of this launcher."""
|
||||
return f"{self.experiment_name}_{self.trial_name}"
|
||||
|
||||
def slurm_name(self, job_name: str) -> str:
|
||||
"""Returns the slurm name of a job."""
|
||||
return f"{self.experiment_name}_{self.trial_name}:{job_name}"
|
||||
|
||||
def log_path_of(self, job_name: str) -> str:
|
||||
log_path = f"{self.fileroot}/logs/{getpass.getuser()}/{self.experiment_name}/{self.trial_name}"
|
||||
os.makedirs(log_path, exist_ok=True)
|
||||
return os.path.join(log_path, f"{job_name}.log")
|
||||
|
||||
def sbatch_path_of(self, job_name: str) -> str:
|
||||
sbatch_path = f"{self.fileroot}/logs/{getpass.getuser()}/{self.experiment_name}/{self.trial_name}"
|
||||
os.makedirs(sbatch_path, exist_ok=True)
|
||||
return os.path.join(sbatch_path, f"{job_name}.sh")
|
||||
|
||||
def submit(self, job_name, cmd, **kwargs):
|
||||
"""Submits and launch a job with SBATCH.
|
||||
|
||||
Args:
|
||||
cmd (str or List[str]): The core command to be executed.
|
||||
"""
|
||||
return self.submit_array(job_name, cmd, count=1, **kwargs)
|
||||
|
||||
def find_job_id(self, job_name: str):
|
||||
job_name = self.slurm_name(job_name)
|
||||
for job_id, job_info in self.jobs.items():
|
||||
if job_info.name == job_name:
|
||||
return job_id
|
||||
return None
|
||||
|
||||
def submit_array(
|
||||
self,
|
||||
job_name: str,
|
||||
cmd: List[str] | str,
|
||||
count: int,
|
||||
nodes: int,
|
||||
n_gpus_per_node: int,
|
||||
cpus_per_task: int,
|
||||
mem_per_task: int, # MB
|
||||
container_image: str,
|
||||
srun_additional_args: str = "",
|
||||
container_mounts: Optional[str] = None,
|
||||
env_vars: Optional[Dict] = None,
|
||||
nodelist: Optional[str] = None,
|
||||
exclude: Optional[str] = None,
|
||||
):
|
||||
"""Submits and launch a job array with SBATCH.
|
||||
Note that a job array has one (unique) slurm name, and one (unique) slurm id.
|
||||
|
||||
Args:
|
||||
job_name (str): The job name of the job array. The actual slurm name will be
|
||||
`<experiment_name>_<trial_name>:<job_name>`.
|
||||
cmd (str or List[str]): The core command to be executed.
|
||||
count (int): The number of jobs in the array.
|
||||
"""
|
||||
assert job_name not in self.job_names, (
|
||||
f"Job {job_name} is already submitted. "
|
||||
"Please use a different job name or stop the existing job."
|
||||
)
|
||||
if isinstance(cmd, str):
|
||||
cmd = [cmd]
|
||||
assert len(cmd) == count, (
|
||||
f"Command length {len(cmd)} does not match the job count {count}. "
|
||||
"Please provide a command for each job in the array."
|
||||
)
|
||||
assert count % nodes == 0, (
|
||||
f"Job count {count} must be divisible by the number of nodes {nodes}. "
|
||||
"Please adjust the job count or the number of nodes."
|
||||
)
|
||||
ntasks_per_node = count // nodes
|
||||
assert n_gpus_per_node % ntasks_per_node == 0, (
|
||||
"GPUs must be evenly distributed across tasks. "
|
||||
f"Current #GPUs per node {n_gpus_per_node}, #tasks per node {ntasks_per_node}."
|
||||
)
|
||||
|
||||
mem_per_cpu = mem_per_task // cpus_per_task # MB per CPU
|
||||
mem_per_node = (
|
||||
mem_per_task * count // nodes + 1024 * 10
|
||||
) # make sure slurm does not run out of resources
|
||||
|
||||
sbatch_options = [
|
||||
f"--job-name={self.slurm_name(job_name)}",
|
||||
f"--output={self.log_path_of(job_name)}",
|
||||
"--open-mode=append",
|
||||
"--no-requeue",
|
||||
f"--nodes={nodes}-{nodes}",
|
||||
f"--ntasks-per-node={ntasks_per_node}",
|
||||
f"--gres=gpu:{n_gpus_per_node}",
|
||||
f"--cpus-per-task={cpus_per_task}",
|
||||
f"--mem={mem_per_node}M",
|
||||
]
|
||||
|
||||
if nodelist:
|
||||
sbatch_options.append(f"--nodelist={nodelist}")
|
||||
if exclude:
|
||||
sbatch_options.append(f"--exclude={exclude}")
|
||||
|
||||
sbatch_options_str = "\n".join([f"#SBATCH {opt}" for opt in sbatch_options])
|
||||
|
||||
if env_vars is None:
|
||||
env_vars = dict()
|
||||
n_gpus_per_task = n_gpus_per_node // ntasks_per_node
|
||||
assert (
|
||||
"CUDA_VISIBLE_DEVICES" not in env_vars
|
||||
), "CUDA_VISIBLE_DEVICES should be automatically resolved by Launcher instead of manually assigned."
|
||||
|
||||
srun_cmds = []
|
||||
for i in range(count):
|
||||
# resolve CUDA_VISIBLE_DEVICES for each task
|
||||
gpu_id_start = (i % ntasks_per_node) * n_gpus_per_task
|
||||
gpu_id_end = ((i % ntasks_per_node) + 1) * n_gpus_per_task
|
||||
node_id = i // ntasks_per_node
|
||||
_env_vars = {
|
||||
**env_vars,
|
||||
"CUDA_VISIBLE_DEVICES": ",".join(
|
||||
str(x) for x in range(gpu_id_start, gpu_id_end)
|
||||
),
|
||||
}
|
||||
# Prepare the command for each job in the array
|
||||
job_cmd = cmd[i]
|
||||
|
||||
if self.container_type == "apptainer":
|
||||
env_string = " ".join(
|
||||
"--env {}={}".format(k, v) for k, v in _env_vars.items()
|
||||
)
|
||||
apptainer_cmd = APPTAINER_CMD_TEMPLATE.format(
|
||||
container_mounts=container_mounts or "",
|
||||
container_env_strings=env_string,
|
||||
container_image=container_image,
|
||||
cmd=job_cmd,
|
||||
)
|
||||
srun_cmd = SRUN_CMD_TEMPLATE.format(
|
||||
additional_args=srun_additional_args,
|
||||
nodes=1,
|
||||
ntasks=1,
|
||||
node_id=node_id,
|
||||
n_gpus_per_node=n_gpus_per_task,
|
||||
cpus_per_task=cpus_per_task,
|
||||
mem_per_cpu=mem_per_cpu,
|
||||
cmd=apptainer_cmd,
|
||||
)
|
||||
elif self.container_type == "none":
|
||||
env_string = "--export=" + ",".join(
|
||||
"{}={}".format(k, v) for k, v in _env_vars.items()
|
||||
)
|
||||
srun_additional_args = srun_additional_args + " " + env_string
|
||||
srun_cmd = SRUN_CMD_TEMPLATE.format(
|
||||
additional_args=srun_additional_args,
|
||||
nodes=1,
|
||||
ntasks=1,
|
||||
node_id=node_id,
|
||||
n_gpus_per_node=n_gpus_per_task,
|
||||
cpus_per_task=cpus_per_task,
|
||||
mem_per_cpu=mem_per_cpu,
|
||||
cmd=job_cmd,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported container type: {self.container_type}. "
|
||||
"Supported types are 'apptainer' and 'none'."
|
||||
)
|
||||
srun_cmds.append(srun_cmd)
|
||||
|
||||
srun_cmds = "\n".join(srun_cmds)
|
||||
sbatch_script = SBATCH_SCRIPT_TEMPLATE.format(
|
||||
sbatch_options=sbatch_options_str,
|
||||
srun_additional_args=srun_additional_args,
|
||||
srun_cmds=srun_cmds,
|
||||
)
|
||||
sbatch_file_path = self.sbatch_path_of(f"{job_name}")
|
||||
with open(sbatch_file_path, "w") as f:
|
||||
f.write(sbatch_script)
|
||||
|
||||
# Submit the job
|
||||
try:
|
||||
output = (
|
||||
subprocess.check_output(["sbatch", sbatch_file_path])
|
||||
.decode("utf-8")
|
||||
.strip()
|
||||
)
|
||||
logger.info(
|
||||
f"Submitted Slurm job {self.slurm_name(job_name)} to scheduler. To check the output, run \n\t`tail -f {self.log_path_of(job_name)}`."
|
||||
)
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.warning(
|
||||
f"Failed to submit job {self.slurm_name(job_name)}. "
|
||||
f"For debugging, please make sure your sbatch command works "
|
||||
f"and check generated sbatch file on {sbatch_file_path}."
|
||||
)
|
||||
logger.error(f"Error message: {e}")
|
||||
return
|
||||
|
||||
match = re.search(r"Submitted batch job (\d+)", output)
|
||||
slurm_job_id = int(match.group(1)) if match else None
|
||||
if slurm_job_id is None:
|
||||
logger.warning(
|
||||
f"Failed to obtain job id for job {self.slurm_name(job_name)}. "
|
||||
f"sbatch output: {output}"
|
||||
)
|
||||
return
|
||||
|
||||
assert isinstance(slurm_job_id, int)
|
||||
self.jobs[slurm_job_id] = JobInfo(
|
||||
name=self.slurm_name(job_name),
|
||||
state=JobState.PENDING,
|
||||
slurm_id=slurm_job_id,
|
||||
)
|
||||
self._update_all()
|
||||
|
||||
def stop(self, job_name, force=False):
|
||||
"""Stops a running job.
|
||||
|
||||
Raises exception if there is no such job, but passes if the job
|
||||
has stopped either successfully or not.
|
||||
|
||||
Args:
|
||||
job_name: The job name of the job array to stop.
|
||||
The actual slurm job name will be `<experiment_name>_<trial_name>:<job_name>`.
|
||||
"""
|
||||
signal = "SIGKILL" if force else "SIGTERM"
|
||||
job_id = self.find_job_id(job_name)
|
||||
if not job_id:
|
||||
return
|
||||
return cancel_jobs(slurm_ids=[job_id], signal=signal)
|
||||
|
||||
def stop_all(self, force=False):
|
||||
"""Stops all running jobs."""
|
||||
signal = "SIGKILL" if force else "SIGTERM"
|
||||
return cancel_jobs(slurm_ids=list(self.jobs.keys()), signal=signal)
|
||||
|
||||
def find(self, job_name) -> JobInfo | None:
|
||||
"""Gets the status of a job of this job.
|
||||
|
||||
Args:
|
||||
job_name: The job name of the job array to find.
|
||||
The actual slurm job name will be `<experiment_name>_<trial_name>:<job_name>`.
|
||||
|
||||
Returns:
|
||||
A JobInfo if the job is found, or None otherwise.
|
||||
"""
|
||||
self._update_all()
|
||||
job_id = self.find_job_id(job_name)
|
||||
return self.jobs[job_id] if job_id else None
|
||||
|
||||
def find_all(self, job_name_regex=".*") -> List[JobInfo]:
|
||||
"""Finds jobs.
|
||||
|
||||
Args:
|
||||
job_name_regex: job name regex.
|
||||
|
||||
Returns:
|
||||
A list of found JobInfo.
|
||||
"""
|
||||
self._update_all()
|
||||
infos = []
|
||||
for r in self.jobs.values():
|
||||
job_name = r.name.split(":")[-1] # Extract the job name from slurm name
|
||||
if re.fullmatch(job_name_regex, job_name):
|
||||
infos.append(r)
|
||||
return infos
|
||||
|
||||
def _find_job_with_status(
|
||||
self,
|
||||
status: List[JobState],
|
||||
) -> List[JobInfo]:
|
||||
"""Finds jobs with the given status.
|
||||
|
||||
Args:
|
||||
status: A list of JobState to filter jobs.
|
||||
|
||||
Returns:
|
||||
A list of JobInfo with the given status.
|
||||
"""
|
||||
self._update_all()
|
||||
return [r for r in self.jobs.values() if r.state in status]
|
||||
|
||||
def wait(
|
||||
self,
|
||||
timeout=None,
|
||||
check_status: Tuple[JobState, ...] = (
|
||||
JobState.CANCELLED,
|
||||
JobState.FAILED,
|
||||
JobState.NOT_FOUND,
|
||||
),
|
||||
remove_status: Tuple[JobState, ...] = (JobState.COMPLETED,),
|
||||
update=False,
|
||||
):
|
||||
"""Waits until all jobs submitted via this client instance finish."""
|
||||
# begin wait
|
||||
deadline = None if timeout is None else time.time() + timeout
|
||||
|
||||
num_jobs_left = len(self.jobs)
|
||||
left = list(self.jobs.keys())
|
||||
logger.info(
|
||||
f"Waiting for {num_jobs_left} jobs. Jobs IDs: "
|
||||
f"{','.join(sorted([str(x.slurm_id) for x in self.jobs.values()]))}."
|
||||
)
|
||||
while len(left) > 0:
|
||||
if len(left) < num_jobs_left:
|
||||
num_jobs_left = len(left)
|
||||
logger.info(f"Waiting for {num_jobs_left} jobs.")
|
||||
if deadline is not None and time.time() > deadline:
|
||||
raise TimeoutError(
|
||||
f"Timeout waiting for {num_jobs_left} jobs. Job ID: "
|
||||
f"{','.join(sorted([str(x.slurm_id) for x in self.jobs.values()]))}."
|
||||
)
|
||||
self._update_all()
|
||||
left = list(self.jobs.keys())
|
||||
for slurm_id in list(left):
|
||||
slurm_info = self.jobs[slurm_id]
|
||||
if slurm_info.slurm_id is None:
|
||||
continue
|
||||
if slurm_info.state in check_status:
|
||||
raise JobException(
|
||||
run_name=self.run_name,
|
||||
worker_type=slurm_info.name,
|
||||
host=slurm_info.host,
|
||||
reason=slurm_info.state,
|
||||
)
|
||||
if slurm_info.state in remove_status:
|
||||
logger.info(
|
||||
f"Job {slurm_info.name} is {slurm_info.state}. (Removed)"
|
||||
)
|
||||
left.remove(slurm_id)
|
||||
if update:
|
||||
self.jobs.pop(slurm_info.slurm_id)
|
||||
time.sleep(SLURM_WAIT_CHECK_TIME_INTERVAL)
|
||||
|
||||
def _update_all(self):
|
||||
"""Updates the status of all jobs."""
|
||||
try:
|
||||
slurm_infos = query_jobs(slurm_ids=list(self.jobs.keys()))
|
||||
for slurm_info in slurm_infos:
|
||||
assert slurm_info.slurm_id is not None
|
||||
self.jobs[slurm_info.slurm_id] = slurm_info
|
||||
except subprocess.CalledProcessError:
|
||||
logger.warning(
|
||||
"Calling squeue failed. Check slurm manually if you continue to see this warning."
|
||||
)
|
||||
|
||||
|
||||
def slurm_main():
|
||||
config, _ = parse_cli_args(sys.argv[2:])
|
||||
config.launcher = to_structured_cfg(config.launcher, LauncherConfig)
|
||||
config.cluster = to_structured_cfg(config.cluster, ClusterSpecConfig)
|
||||
config.sglang = to_structured_cfg(config.sglang, SGLangConfig)
|
||||
validate_config_for_distributed_launcher(config)
|
||||
|
||||
name_resolve.reconfigure(config.cluster.name_resolve)
|
||||
name_resolve.clear_subtree(
|
||||
names.trial_root(
|
||||
experiment_name=config.experiment_name, trial_name=config.trial_name
|
||||
)
|
||||
)
|
||||
|
||||
n_nodes = config.cluster.n_nodes
|
||||
n_gpus_per_node = config.cluster.n_gpus_per_node
|
||||
|
||||
launcher = SlurmLauncher(
|
||||
experiment_name=config.experiment_name,
|
||||
trial_name=config.trial_name,
|
||||
fileroot=config.cluster.fileroot,
|
||||
container_type=config.launcher.slurm.container_type,
|
||||
)
|
||||
allocation_mode = config.allocation_mode
|
||||
allocation_mode = AllocationMode.from_str(allocation_mode)
|
||||
sglang_cmds = []
|
||||
sglang_addrs = []
|
||||
n_sglang_nodes = 0
|
||||
if allocation_mode.type_ == AllocationType.DECOUPLED_SGLANG:
|
||||
# Launcher should launch SGLang servers according to allocation mode.
|
||||
sglang_tp_size = allocation_mode.gen_tp_size
|
||||
n_sglang_servers = allocation_mode.gen_dp_size
|
||||
n_sglang_nodes = allocation_mode.gen_world_size // n_gpus_per_node
|
||||
|
||||
base_seed = config.sglang.random_seed
|
||||
sglang_server_cmd_template = f"python3 -m arealite.launcher.sglang_server {' '.join(sys.argv[2:])} sglang.random_seed={{seed}}"
|
||||
for i in range(n_sglang_servers):
|
||||
sglang_cmd = sglang_server_cmd_template.format(
|
||||
seed=base_seed + i,
|
||||
)
|
||||
sglang_cmds.append(sglang_cmd)
|
||||
|
||||
launcher.submit_array(
|
||||
job_name="llm_server",
|
||||
cmd=sglang_cmds,
|
||||
count=n_sglang_servers,
|
||||
nodes=n_sglang_nodes,
|
||||
n_gpus_per_node=config.cluster.n_gpus_per_node,
|
||||
cpus_per_task=config.launcher.inference_server_cpus_per_gpu
|
||||
* sglang_tp_size,
|
||||
mem_per_task=config.launcher.inference_server_mem_per_gpu * sglang_tp_size,
|
||||
srun_additional_args=config.launcher.slurm.srun_additional_args,
|
||||
container_image=config.launcher.slurm.inference_server_image,
|
||||
container_mounts=config.launcher.slurm.mount,
|
||||
env_vars=get_env_vars(
|
||||
config.cluster.cluster_name,
|
||||
config.launcher.inference_server_env_vars,
|
||||
),
|
||||
)
|
||||
# Get SGLang server addresses by name resolve
|
||||
try:
|
||||
sglang_addrs = wait_sglang_server_addrs(
|
||||
config.experiment_name,
|
||||
config.trial_name,
|
||||
n_sglang_servers,
|
||||
)
|
||||
except TimeoutError as e:
|
||||
launcher.stop_all(force=True)
|
||||
raise e
|
||||
|
||||
trainer_n_nodes = n_nodes - n_sglang_nodes
|
||||
# Here $head_node_ip is the IP address of the first node in the job array.
|
||||
# $trainer_port is a free port on the head node.
|
||||
# Both of them are obtained in by the SBATCH script.
|
||||
trainer_cmd_template = (
|
||||
f"torchrun --nnodes={{nnodes}} --nproc-per-node={{nproc_per_node}} --node-rank {{node_rank}} "
|
||||
f"--master-addr $head_node_ip --master-port $trainer_port {' '.join(sys.argv[1:])}"
|
||||
)
|
||||
|
||||
trainer_cmds = []
|
||||
for i in range(trainer_n_nodes):
|
||||
# In slurm, we launch trainer in the granularity of nodes with torchrun command.
|
||||
trainer_cmds.append(
|
||||
trainer_cmd_template.format(
|
||||
nnodes=trainer_n_nodes,
|
||||
nproc_per_node=config.cluster.n_gpus_per_node,
|
||||
node_rank=i,
|
||||
)
|
||||
)
|
||||
|
||||
if not config.server_only:
|
||||
# launch trainers
|
||||
launcher.submit_array(
|
||||
job_name="trainer",
|
||||
cmd=trainer_cmds,
|
||||
count=trainer_n_nodes,
|
||||
nodes=trainer_n_nodes,
|
||||
n_gpus_per_node=config.cluster.n_gpus_per_node,
|
||||
cpus_per_task=config.launcher.trainer_cpus_per_gpu
|
||||
* config.cluster.n_gpus_per_node,
|
||||
mem_per_task=config.launcher.trainer_mem_per_gpu
|
||||
* config.cluster.n_gpus_per_node,
|
||||
container_image=config.launcher.slurm.trainer_image,
|
||||
srun_additional_args=config.launcher.slurm.srun_additional_args,
|
||||
container_mounts=config.launcher.slurm.mount,
|
||||
env_vars=dict(
|
||||
**get_env_vars(
|
||||
config.cluster.cluster_name,
|
||||
config.launcher.trainer_env_vars,
|
||||
),
|
||||
AREAL_LLM_SERVER_ADDRS=",".join(sglang_addrs),
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
launcher.wait(
|
||||
check_status=(
|
||||
JobState.CANCELLED,
|
||||
JobState.FAILED,
|
||||
JobState.NOT_FOUND,
|
||||
JobState.COMPLETED,
|
||||
),
|
||||
remove_status=(),
|
||||
)
|
||||
except (KeyboardInterrupt, JobException, TimeoutError) as e:
|
||||
launcher.stop_all(force=True)
|
||||
raise e
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# usage: python -m arealite.launcher.slurm <entry_point> \
|
||||
# --config <config_path> [<additional_args>]
|
||||
slurm_main()
|
|
@ -2,11 +2,9 @@ import os
|
|||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
import torch
|
||||
|
||||
from arealite.api.cli_args import (
|
||||
InferenceEngineConfig,
|
||||
|
@ -18,7 +16,6 @@ from arealite.api.io_struct import FinetuneSpec, WeightUpdateMeta
|
|||
from arealite.engine.fsdp_engine import FSDPEngine
|
||||
from arealite.engine.sglang_remote import RemoteSGLangEngine
|
||||
from arealite.utils.network import find_free_ports
|
||||
from realhf.api.core.data_api import load_hf_tokenizer
|
||||
from realhf.base import network
|
||||
|
||||
EXPR_NAME = "test_fsdp_engine_nccl"
|
||||
|
|
|
@ -399,6 +399,12 @@ def pad_packed_tensor_dict(
|
|||
padded_data[key] = new_max_seqlen
|
||||
elif torch.is_tensor(value) and value.numel() == total_length:
|
||||
# Pad the tensor to the new total length
|
||||
if key == "position_ids":
|
||||
# transformers will compute flash-attn arguments (e.g., cu_seqlens_q)
|
||||
# according to this position ids.
|
||||
pad = torch.arange(pad_length, dtype=torch.long, device=value.device)
|
||||
padded_tensor = torch.cat([value, pad])
|
||||
else:
|
||||
padded_tensor = torch.nn.functional.pad(
|
||||
value, (0, pad_length), value=pad_value
|
||||
)
|
||||
|
|
|
@ -0,0 +1,102 @@
|
|||
import getpass
|
||||
import os
|
||||
import pathlib
|
||||
import time
|
||||
from typing import Dict, Optional
|
||||
|
||||
from arealite.api.io_struct import AllocationMode, AllocationType
|
||||
from realhf.base import logging, name_resolve, names
|
||||
|
||||
logger = logging.getLogger("Launcher Utils")
|
||||
|
||||
LOCAL_CACHE_DIR = "/tmp/arealite"
|
||||
PYTORCH_KERNEL_CACHE_PATH = (
|
||||
f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/torch/kernels/"
|
||||
)
|
||||
TRITON_CACHE_PATH = f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/triton/"
|
||||
os.makedirs(PYTORCH_KERNEL_CACHE_PATH, exist_ok=True)
|
||||
os.makedirs(TRITON_CACHE_PATH, exist_ok=True)
|
||||
BASE_ENVIRONS = {
|
||||
"TOKENIZERS_PARALLELISM": "true",
|
||||
"PYTORCH_KERNEL_CACHE_PATH": PYTORCH_KERNEL_CACHE_PATH,
|
||||
"TRITON_CACHE_DIR": TRITON_CACHE_PATH,
|
||||
"CUDA_DEVICE_MAX_CONNECTIONS": "1",
|
||||
"PYTHONPATH": str(pathlib.Path(__file__).resolve().parent.parent.parent),
|
||||
}
|
||||
NA132_ENVIRONS = {
|
||||
"NCCL_SOCKET_IFNAME": "bond0",
|
||||
"NCCL_NET_PLUGIN": "",
|
||||
"NCCL_IB_GID_INDEX": "3",
|
||||
"NCCL_IB_TIMEOUT": "2",
|
||||
"NCCL_IB_RETRY_CNT": "7",
|
||||
"NCCL_IB_SL": "5",
|
||||
"NCCL_IB_TC": "136",
|
||||
"NCCL_IB_HCA": "mlx5_bond",
|
||||
"NCCL_IB_QPS_PER_CONNECTION": "8",
|
||||
"NCCL_SET_THREAD_NAME": "1",
|
||||
"NCCL_DEBUG": "WARN",
|
||||
"NCCL_DEBUG_SUBSYS": "INIT,TUNING,GRAPH",
|
||||
}
|
||||
SGLANG_SERVER_WAIT_TIMEOUT_SECONDS = 180
|
||||
|
||||
|
||||
def get_env_vars(
|
||||
cluster_name: str, additional_env_vars: Optional[str] = None
|
||||
) -> Dict[str, str]:
|
||||
"""Returns the environment variables for the cluster."""
|
||||
_additional_env_vars = (
|
||||
dict(item.split("=") for item in additional_env_vars.split(","))
|
||||
if additional_env_vars
|
||||
else dict()
|
||||
)
|
||||
if cluster_name == "na132":
|
||||
return {**BASE_ENVIRONS, **NA132_ENVIRONS, **_additional_env_vars}
|
||||
else:
|
||||
return {**BASE_ENVIRONS, **_additional_env_vars}
|
||||
|
||||
|
||||
def wait_sglang_server_addrs(
|
||||
experiment_name: str,
|
||||
trial_name: str,
|
||||
n_sglang_servers: int,
|
||||
):
|
||||
# Get SGLang slurm nodes, find the hosts
|
||||
name = names.gen_servers(experiment_name, trial_name)
|
||||
start = time.perf_counter()
|
||||
while True:
|
||||
sglang_addrs = name_resolve.get_subtree(name)
|
||||
if len(sglang_addrs) >= n_sglang_servers:
|
||||
logger.info(
|
||||
f"Found {len(sglang_addrs)} SGLang servers: {', '.join(sglang_addrs)}"
|
||||
)
|
||||
break
|
||||
|
||||
time.sleep(1)
|
||||
if time.perf_counter() - start > SGLANG_SERVER_WAIT_TIMEOUT_SECONDS:
|
||||
raise TimeoutError(
|
||||
f"Timeout waiting for SGLang servers to be ready. "
|
||||
f"Expected {n_sglang_servers} servers, found {len(sglang_addrs)}."
|
||||
)
|
||||
return sglang_addrs
|
||||
|
||||
|
||||
def validate_config_for_distributed_launcher(config):
|
||||
n_nodes = config.cluster.n_nodes
|
||||
n_gpus_per_node = config.cluster.n_gpus_per_node
|
||||
allocation_mode = config.allocation_mode
|
||||
allocation_mode = AllocationMode.from_str(allocation_mode)
|
||||
assert (
|
||||
allocation_mode.gen_world_size + allocation_mode.train_world_size
|
||||
== n_nodes * n_gpus_per_node
|
||||
), (
|
||||
f"#GPUs required for allocation mode {allocation_mode.gen_world_size + allocation_mode.train_world_size} "
|
||||
f"is not equal to #GPUs in the config {n_nodes * n_gpus_per_node}."
|
||||
)
|
||||
if allocation_mode.type_ == AllocationType.DECOUPLED_SGLANG:
|
||||
# Launcher should launch SGLang servers according to allocation mode.
|
||||
assert (
|
||||
allocation_mode.gen_pp_size == 1
|
||||
), "Pipeline generation in SGLang is not supported for now."
|
||||
assert (
|
||||
allocation_mode.gen_tp_size <= config.cluster.n_gpus_per_node
|
||||
), "Currently only support SGLang TP size less <= #GPUs per node."
@@ -1,48 +0,0 @@
from typing import List

import torch
from tensordict import TensorDict


def concat_padded_tensors(
    tensor_dicts: List[TensorDict], pad_value: float = 0.0
) -> TensorDict:
    """Concatenate and pad tensors from multiple padded tensor dictionaries."""
    if not tensor_dicts:
        return TensorDict()

    batch_sizes = [tuple(d.batch_size) for d in tensor_dicts]
    new_batch_size = [sum(x[0] for x in batch_sizes), *batch_sizes[0][1:]]

    # Find max sequence length across all dictionaries
    assert all("attention_mask" in td for td in tensor_dicts)
    max_length = max([x["attention_mask"].shape[1] for x in tensor_dicts])
    result = {}
    # Process each key
    for key in tensor_dicts[0].keys():
        tensors_to_concat = []
        for tensor_dict in tensor_dicts:
            tensor = tensor_dict[key]
            # Skip 1D tensors like rewards
            if len(tensor.shape) == 1:
                tensors_to_concat.append(tensor)
                continue
            current_length = tensor.shape[1]
            if current_length < max_length:
                # Pad tensor to max_length
                pad_width = max_length - current_length
                if key == "attention_mask":
                    # Pad attention mask with 0s
                    padding = torch.zeros(
                        (tensor.shape[0], pad_width), dtype=tensor.dtype
                    )
                else:
                    # Pad feature tensors with pad_value
                    padding = torch.full(
                        (tensor.shape[0], pad_width), pad_value, dtype=tensor.dtype
                    )
                tensor = torch.cat([tensor, padding], dim=1)
            tensors_to_concat.append(tensor)

        result[key] = torch.cat(tensors_to_concat, dim=0)
    return TensorDict(result, batch_size=new_batch_size)
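A minimal usage sketch of this helper, with toy tensors whose shapes and values are chosen purely for illustration:

import torch
from tensordict import TensorDict

a = TensorDict(
    {
        "input_ids": torch.tensor([[1, 2, 3]]),
        "attention_mask": torch.ones(1, 3, dtype=torch.bool),
        "rewards": torch.tensor([1.0]),
    },
    batch_size=[1],
)
b = TensorDict(
    {
        "input_ids": torch.tensor([[4, 5]]),
        "attention_mask": torch.ones(1, 2, dtype=torch.bool),
        "rewards": torch.tensor([0.0]),
    },
    batch_size=[1],
)
out = concat_padded_tensors([a, b])
# out.batch_size == torch.Size([2]); out["input_ids"] has shape (2, 3), the second
# row is right-padded with 0, and its last attention_mask entry is False.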
@@ -0,0 +1,23 @@
import ray
from ray.util.placement_group import PlacementGroup
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

from arealite.utils.network import find_free_ports, gethostip


def get_placement_group_master_ip_and_port(placement_group: PlacementGroup):
    def _master_ip_and_port():
        host_ip = gethostip()
        port = find_free_ports(1, (10000, 60000))[0]
        return host_ip, port

    future = ray.remote(
        num_cpus=1,
        num_gpus=0,
        memory=10 * 1024 * 1024,  # Convert MB to bytes
        scheduling_strategy=PlacementGroupSchedulingStrategy(
            placement_group=placement_group,
            placement_group_bundle_index=0,
        ),
    )(_master_ip_and_port).remote()
    return ray.get(future)
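A minimal usage sketch, assuming a Ray cluster is already running and the helper above is importable; the bundle sizes are illustrative:

import ray
from ray.util.placement_group import placement_group

ray.init(address="auto")  # attach to an existing Ray cluster (assumption)
# Bundle 0 must reserve at least the 1 CPU and ~10 MB that the probe task above requests.
pg = placement_group(
    [{"CPU": 2, "GPU": 1, "memory": 1024**3}, {"CPU": 2, "GPU": 1, "memory": 1024**3}],
    strategy="PACK",
)
ray.get(pg.ready())
master_ip, master_port = get_placement_group_master_ip_and_port(pg)
# master_ip:master_port can then serve, e.g., as a torch.distributed rendezvous address.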
@@ -0,0 +1,154 @@
import subprocess
from typing import List, Literal, Optional

from realhf.base import logging
from realhf.scheduler.client import JobInfo, JobState

logger = logging.getLogger("Slurm Utils")


SQUEUE_FIELDS = [
    "JobID",
    "State",
    "SubmitTime",
    "StartTime",
    "Name",
    "NodeList",
    "UserName",
    "MaxCPUs",
    "cpus-per-task",
    "NumTasks",
    "tres-alloc",
]
STATUS_MAPPING = {
    "RUNNING": JobState.RUNNING,
    "COMPLETING": JobState.RUNNING,
    "PENDING": JobState.PENDING,
    "CANCELLED": JobState.CANCELLED,
    "FAILED": JobState.FAILED,
    "COMPLETED": JobState.COMPLETED,
    "OUT_OF_MEMORY": JobState.FAILED,
    "DEADLINE": JobState.COMPLETED,
    "TIMEOUT": JobState.COMPLETED,
}

SBATCH_SCRIPT_TEMPLATE = """#!/bin/bash
{sbatch_options}

# Getting the node names
nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
echo nodes=$nodes

nodes_array=($nodes)
echo node_array=$nodes_array

head_node=${{nodes_array[0]}}
echo head_node=$head_node

# Getting the head node IP address
head_node_ip=$(srun {srun_additional_args} --nodes=1 --ntasks=1 -n1 -c1 --mem=10M --nodelist="$head_node" hostname --ip-address)
echo head_node_ip=$head_node_ip

# Find a free port on the head node
# Wonderful linux command to find a random free port (between 10000 and 60000) by deepseek
trainer_port=$(srun {srun_additional_args} --nodes=1 --ntasks=1 -n1 -c1 --mem=10M --nodelist="$head_node" bash -c "comm -23 <(seq 10000 60000 | sort) <(ss -tan | awk '{{print $4}}' | cut -d':' -f2 | grep '[0-9]\\{{1,5\\}}' | sort -u) | shuf | head -n 1")
echo trainer_port=$trainer_port

# srun commands
{srun_cmds}

wait
"""

SRUN_CMD_TEMPLATE: str = """srun {additional_args} \\
    --nodelist=${{nodes_array[{node_id}]}} --nodes={nodes} --ntasks={ntasks} \\
    --gres=gpu:{n_gpus_per_node} --cpus-per-task={cpus_per_task} --mem-per-cpu={mem_per_cpu}M \\
    {cmd} &
"""

APPTAINER_CMD_TEMPLATE: str = """singularity exec --no-home --writable-tmpfs --nv --pid \\
    --bind {container_mounts} \\
    {container_env_strings} \\
    {container_image} \\
    {cmd}"""


def cancel_jobs(
    slurm_names: Optional[List[str]] = None,
    slurm_ids: Optional[List[int]] = None,
    signal: Literal["SIGINT", "SIGKILL"] = "SIGKILL",
):
    assert (
        slurm_names is not None or slurm_ids is not None
    ), "Must specify slurm_names or slurm_ids."
    assert not (
        slurm_names and slurm_ids
    ), "Cannot specify both slurm_names and slurm_ids."
    cmd = ["scancel", "-s", signal]
    if slurm_names is not None:
        cmd += ["-n", ",".join(slurm_names)]
    elif slurm_ids is not None:
        cmd += [",".join(str(s) for s in slurm_ids)]
    subprocess.check_call(cmd)
    logger.info(
        f"Cancelled Slurm job with signal {signal}: "
        f"slurm identifiers {slurm_names if slurm_ids is None else slurm_ids}. CMD: {cmd}"
    )


def query_jobs(
    slurm_names: Optional[List[str]] = None,
    slurm_ids: Optional[List[int]] = None,
    status: str = "all",
    delimiter: str = "__PSI__",
) -> List[JobInfo]:
    squeue_format = f":.{delimiter},".join(SQUEUE_FIELDS)
    cmd = ["squeue", "-O", squeue_format, f"-t{status}"]
    if slurm_names is not None:
        cmd += ["-n", ",".join(slurm_names)]
    if slurm_ids is not None:
        cmd += ["-j", ",".join([str(s) for s in slurm_ids])]

    output = (
        subprocess.check_output(cmd, stderr=subprocess.DEVNULL).decode("ascii").strip()
    )
    rs = []
    for line in output.split("\n")[1:]:
        job_id, state, submit_time, start_time, slurm_name, nodelist, *_ = line.split(
            delimiter
        )
        rs.append(
            JobInfo(
                name=slurm_name,
                state=STATUS_MAPPING[state],
                host=nodelist,
                submit_time=submit_time,
                start_time=start_time,
                slurm_id=int(job_id.strip()),
            )
        )
    return rs


def parse_slurm_nodelist(nodelist: str) -> List[str]:
    return (
        subprocess.check_output(
            [
                "scontrol",
                "show",
                "hostnames",
                nodelist,
            ]
        )
        .decode("utf-8")
        .strip()
        .split("\n")
    )


def get_slurm_host_ip(node: str, srun_additional_args: str):
    try:
        cmd = f"srun {srun_additional_args} --immediate=1 --nodes=1 --ntasks=1 -n1 -c1 --mem=10M --nodelist={node} hostname --ip-address"
        return subprocess.check_output(cmd.split(" ")).decode("utf-8").strip()
    except subprocess.CalledProcessError:
        logger.warning(f"Failed to get Slurm host IP for node {node}.")
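A sketch of how these templates compose into a batch script; every value below is a hypothetical placeholder, not taken from a real launcher invocation:

srun_cmd = SRUN_CMD_TEMPLATE.format(
    additional_args="",
    node_id=0,
    nodes=1,
    ntasks=1,
    n_gpus_per_node=8,
    cpus_per_task=15,
    mem_per_cpu=10240,
    cmd=APPTAINER_CMD_TEMPLATE.format(
        container_mounts="/storage:/storage",
        container_env_strings="",
        container_image="/path/to/image.sif",  # hypothetical image path
        cmd="python -m my_entrypoint",  # hypothetical trainer command
    ),
)
sbatch_script = SBATCH_SCRIPT_TEMPLATE.format(
    sbatch_options="#SBATCH --job-name=demo\n#SBATCH --nodes=1",
    srun_additional_args="",
    srun_cmds=srun_cmd,
)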
@@ -73,7 +73,7 @@ class StatsLogger:
         )
         if isinstance(data, Dict):
             data = [data]
-        log_step = max(global_step, self._last_commit_step)
+        log_step = max(global_step, self._last_commit_step + 1)
         for i, item in enumerate(data):
             self.info(f"Stats ({i+1}/{len(data)}):")
             self.print_stats(item)
@@ -1,11 +1,14 @@
 import asyncio
+import os
 import uuid

+import colorama
 import torch
 from tensordict import TensorDict
 from transformers import PreTrainedTokenizerFast

 from arealite.api.cli_args import GenerationHyperparameters
+from arealite.api.engine_api import InferenceEngine
 from arealite.api.io_struct import LLMRequest
 from arealite.api.workflow_api import RolloutWorkflow
 from arealite.utils.data import concat_padded_tensors
@@ -18,13 +21,17 @@ class RLVRWorkflow(RolloutWorkflow):
         gconfig: GenerationHyperparameters,
         tokenizer: PreTrainedTokenizerFast,
         enable_thinking: bool,
+        dump_dir: str | None = None,
     ):
         self.reward_fn = reward_fn
         self.gconfig = gconfig
         self.tokenizer = tokenizer
         self.enable_thinking = enable_thinking
+        self.dump_dir = dump_dir
+        if self.dump_dir is not None and not os.path.exists(self.dump_dir):
+            os.makedirs(self.dump_dir, exist_ok=True)

-    async def arun_episode(self, engine, data):
+    async def arun_episode(self, engine: InferenceEngine, data):
         input_ids = self.tokenizer.apply_chat_template(
             data["messages"],
             tokenize=True,
@@ -39,6 +46,12 @@ class RLVRWorkflow(RolloutWorkflow):
         )
         resps = await asyncio.gather(*[engine.agenerate(req) for _ in range(n_samples)])

+        version = engine.get_version()
+        prompt_strs = []
+        completions_strs = []
+        rewards = []
+        seqlens = []
+
         results = []
         for resp in resps:
             seq = resp.input_tokens + resp.output_tokens
@@ -46,13 +59,19 @@ class RLVRWorkflow(RolloutWorkflow):
             loss_mask = [0] * resp.input_len + [1] * resp.output_len
             versions = [-1] * resp.input_len + resp.output_versions

+            prompt_str = self.tokenizer.decode(input_ids)
+            completions_str = self.tokenizer.decode(resp.output_tokens)
+            prompt_strs.append(prompt_str)
+            completions_strs.append(completions_str)
+            seqlens.append(len(seq))
             reward = self.reward_fn(
-                prompt=self.tokenizer.decode(input_ids),
-                completions=self.tokenizer.decode(resp.output_tokens),
+                prompt=prompt_str,
+                completions=completions_str,
                 prompt_ids=resp.input_tokens,
                 completion_ids=resp.output_tokens,
                 **data,
             )
+            rewards.append(reward)
             res = dict(
                 # unsqueeze to add an additional batch dimension
                 input_ids=torch.tensor(seq).unsqueeze(0),
@@ -65,4 +84,31 @@ class RLVRWorkflow(RolloutWorkflow):
             )
             results.append(TensorDict(res, batch_size=[1]))

+        if self.dump_dir is not None:
+            os.makedirs(os.path.join(self.dump_dir, str(version)), exist_ok=True)
+            # Get the unique identifier for this prompt
+            qid = None
+            for key in ["query_id", "id", "qid"]:
+                qid = data.get(key, None)
+                if qid is not None:
+                    break
+            qid = qid or uuid.uuid4().hex
+
+            # Dump rollout to file
+            with open(
+                os.path.join(self.dump_dir, str(version), f"{qid}.txt"), "a"
+            ) as f:
+                n_samples = self.gconfig.n_samples
+                for i, (p, c, r, sl) in enumerate(
+                    zip(prompt_strs, completions_strs, rewards, seqlens)
+                ):
+                    info = "\n".join(
+                        [
+                            f"idx: {i + 1} / {n_samples}, seqlen: {sl}, reward is {r}.",
+                            f"prompt is \n{colorama.Fore.YELLOW + colorama.Style.DIM}{p}{colorama.Style.RESET_ALL}",
+                            f"sequence is: \n{colorama.Fore.YELLOW + colorama.Style.DIM}{c}{colorama.Style.RESET_ALL}",
+                        ]
+                    )
+                    f.write(info + "\n")
+
         return concat_padded_tensors(results)
@@ -0,0 +1,310 @@
import asyncio
import os
import shutil
import sys
import uuid

import colorama
import torch
import torch.distributed as dist
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node
from tensordict import TensorDict
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import PreTrainedTokenizerFast

from arealite.api.cli_args import (
    GenerationHyperparameters,
    GRPOConfig,
    load_expr_config,
)
from arealite.api.io_struct import FinetuneSpec, LLMRequest, WeightUpdateMeta
from arealite.api.workflow_api import RolloutWorkflow
from arealite.engine.ppo.actor import FSDPPPOActor
from arealite.engine.sglang_remote import RemoteSGLangEngine
from arealite.utils.data import concat_padded_tensors
from arealite.utils.device import log_gpu_stats
from arealite.utils.saver import Saver
from arealite.utils.stats_logger import StatsLogger
from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import logging, seeding, stats_tracker

logger = logging.getLogger("boba math")


class RLVRWorkflow(RolloutWorkflow):
    def __init__(
        self,
        reward_fn,
        gconfig: GenerationHyperparameters,
        tokenizer: PreTrainedTokenizerFast,
        dump_dir: str | None = None,
    ):
        self.reward_fn = reward_fn
        self.gconfig = gconfig
        self.tokenizer = tokenizer
        self.dump_dir = dump_dir
        if self.dump_dir is not None and not os.path.exists(self.dump_dir):
            os.makedirs(self.dump_dir, exist_ok=True)

    async def arun_episode(self, engine, data):
        input_ids = self.tokenizer.encode(data["prompt"])
        n_samples = self.gconfig.n_samples
        req = LLMRequest(
            rid=uuid.uuid4().hex,
            input_ids=input_ids,
            gconfig=self.gconfig.new(n_samples=1),
        )
        resps = await asyncio.gather(*[engine.agenerate(req) for _ in range(n_samples)])

        version = engine.get_version()
        prompt_strs = []
        completions_strs = []
        rewards = []
        seqlens = []

        results = []
        for resp in resps:
            seq = resp.input_tokens + resp.output_tokens
            logprobs = [0.0] * resp.input_len + resp.output_logprobs
            loss_mask = [0] * resp.input_len + [1] * resp.output_len
            versions = [-1] * resp.input_len + resp.output_versions

            prompt_str = data["prompt"]
            completions_str = self.tokenizer.decode(resp.output_tokens)
            prompt_strs.append(prompt_str)
            completions_strs.append(completions_str)
            seqlens.append(len(seq))
            reward = self.reward_fn(
                completions=completions_str,
                prompt_ids=resp.input_tokens,
                completion_ids=resp.output_tokens,
                **data,
            )
            rewards.append(reward)
            res = dict(
                # unsqueeze to add an additional batch dimension
                input_ids=torch.tensor(seq).unsqueeze(0),
                loss_mask=torch.tensor(loss_mask).unsqueeze(0),
                logprobs=torch.tensor(logprobs).unsqueeze(0),
                versions=torch.tensor(versions).unsqueeze(0),
                attention_mask=torch.ones(len(seq), dtype=torch.bool).unsqueeze(0),
                # reward
                rewards=torch.tensor([float(reward)]),
            )
            results.append(TensorDict(res, batch_size=[1]))

        if self.dump_dir is not None:
            os.makedirs(os.path.join(self.dump_dir, str(version)), exist_ok=True)
            # Get the unique identifier for this prompt
            qid = None
            for key in ["query_id", "id", "qid"]:
                qid = data.get(key, None)
                if qid is not None:
                    break
            qid = qid or uuid.uuid4().hex

            # Dump rollout to file
            with open(
                os.path.join(self.dump_dir, str(version), f"{qid}.txt"), "a"
            ) as f:
                n_samples = self.gconfig.n_samples
                for i, (p, c, r, sl) in enumerate(
                    zip(prompt_strs, completions_strs, rewards, seqlens)
                ):
                    info = "\n".join(
                        [
                            f"idx: {i + 1} / {n_samples}, seqlen: {sl}, reward is {r}.",
                            f"prompt is \n{colorama.Fore.YELLOW + colorama.Style.DIM}{p}{colorama.Style.RESET_ALL}",
                            f"sequence is: \n{colorama.Fore.YELLOW + colorama.Style.DIM}{c}{colorama.Style.RESET_ALL}",
                        ]
                    )
                    f.write(info + "\n")

        return concat_padded_tensors(results)


def get_boba_math_dataset(tokenizer, rank, world_size):
    dataset = load_dataset(
        path="json",
        split="train",
        data_files="/storage/openpsi/users/xushusheng.xss/training_data/boba_106k_0319.jsonl",
    )
    dataset = dataset.filter(lambda x: len(tokenizer.encode(x["prompt"])) <= 1024)
    return split_dataset_by_node(dataset, rank=rank, world_size=world_size)


def boba_reward_fn(
    prompt, completions, prompt_ids, completion_ids, query_id, solutions, **kwargs
):
    from pebble import ProcessExpired, ProcessPool

    from realhf.impl.dataset.math_parser import process_results

    jobs = []
    with ProcessPool(max_workers=1) as executor:
        for sol in solutions:
            job = executor.schedule(
                process_results, args=[completions, sol], timeout=15
            )
            jobs.append(job)

        label = 0
        for job in jobs:
            try:
                x = job.result()
            except TimeoutError:
                # print("[debug: timeout]")
                logger.warning("Timeout occurred while verifying the math answer.")
                x = (0, "timeout", "timeout")
            except ProcessExpired as e:
                logger.warning(f"Process terminated abnormally: {e}")
                x = (0, "error", "error")
            except Exception as e:
                logger.warning(f"Other error occurred: {e.__class__.__name__}, {e}")
                x = (0, "error", "error")
            label = label or x[0]
    return label
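Any callable with a compatible keyword signature can be plugged in as reward_fn: the workflow above calls it with completions, prompt_ids, and completion_ids, plus every column of the dataset row via **data. A hypothetical exact-match reward as a minimal sketch (the `answer` column is an assumption, not part of this dataset):

def exact_match_reward_fn(
    prompt, completions, prompt_ids, completion_ids, answer, **kwargs
):
    # `answer` is assumed to be a column of the dataset row (hypothetical example).
    return int(completions.strip() == str(answer).strip())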


def main(args):
    config, _ = load_expr_config(args, GRPOConfig)
    config: GRPOConfig

    rank = int(os.getenv("RANK"))
    world_size = int(os.getenv("WORLD_SIZE"))
    tokenizer = load_hf_tokenizer(config.tokenizer_path)

    seeding.set_random_seed(config.seed, key=f"trainer{rank}")

    # Create dataset and dataloaders
    train_dataloader = StatefulDataLoader(
        get_boba_math_dataset(tokenizer, rank, world_size),
        batch_size=config.train_dataset.batch_size // world_size,
        shuffle=config.train_dataset.shuffle,
        num_workers=config.train_dataset.num_workers,
        collate_fn=lambda x: x,
        drop_last=config.train_dataset.drop_last,
    )
    ft_spec = FinetuneSpec(
        total_train_epochs=config.total_train_epochs,
        dataset_size=len(train_dataloader) * config.train_dataset.batch_size,
        train_batch_size=config.train_dataset.batch_size,
    )

    # Initialize inference engine
    rollout = RemoteSGLangEngine(config.rollout)
    rollout.initialize(None, ft_spec)

    # Initialize train engine
    actor = FSDPPPOActor(config=config.actor)
    actor.initialize(None, ft_spec)
    ref = None
    if config.actor.kl_ctl > 0 and config.ref is not None:
        ref = FSDPPPOActor(config=config.ref)
        ref.initialize(None, ft_spec)

    # Create rollout workflow
    if tokenizer.pad_token_id not in config.gconfig.stop_token_ids:
        config.gconfig.stop_token_ids.append(tokenizer.pad_token_id)
    if tokenizer.eos_token_id not in config.gconfig.stop_token_ids:
        config.gconfig.stop_token_ids.append(tokenizer.eos_token_id)
    workflow = RLVRWorkflow(
        reward_fn=boba_reward_fn,
        gconfig=config.gconfig,
        tokenizer=tokenizer,
        dump_dir=os.path.join(
            StatsLogger.get_log_path(config.stats_logger), "generated"
        ),
    )

    # Run training.
    saver = Saver(config.saver, ft_spec, for_recover=False)
    logger = StatsLogger(config.stats_logger, ft_spec)

    total_epochs = config.total_train_epochs
    steps_per_epoch = len(train_dataloader)
    max_steps = total_epochs * steps_per_epoch

logger.info(f"total_epochs={total_epochs} step_per_epoch={steps_per_epoch}")
    data_generator = iter(train_dataloader)
    for global_step in range(max_steps):
        epoch = global_step // steps_per_epoch
        step = global_step % steps_per_epoch

        with stats_tracker.record_timing("rollout"):
            if config.async_training:
                batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
            else:
                try:
                    data = next(data_generator)
                except StopIteration:
                    data_generator = iter(train_dataloader)
                    data = next(data_generator)
                batch = rollout.rollout_batch(data, workflow=workflow)

        batch = batch.to(actor.device)
        # Create barrier to synchronize all rollout processes.
        dist.barrier()
        torch.cuda.synchronize()

        if config.actor.recompute_logprob or config.actor.use_decoupled_loss:
            with stats_tracker.record_timing("recompute_logp"):
                logp = actor.compute_logp(batch)
                batch["prox_logp"] = logp
                log_gpu_stats("recompute logp")

        if ref is not None:
            with stats_tracker.record_timing("ref_logp"):
                batch["ref_logp"] = ref.compute_logp(batch)
                log_gpu_stats("ref logp")

        with stats_tracker.record_timing("compute_advantage"):
            actor.compute_advantages(batch)
            log_gpu_stats("compute advantages")

        with (
            stats_tracker.record_timing("train_step"),
            stats_tracker.scope("grpo_actor"),
        ):
            stats = actor.ppo_update(batch)
            actor.step_lr_scheduler()
            log_gpu_stats("ppo update")

        with stats_tracker.record_timing("update_weights"):
            path = os.path.join(
                Saver.get_save_checkpoint_root(config.saver),
                "update_weights",
                str(global_step + 1),
            )
            meta = WeightUpdateMeta(
                type="disk",
                path=path,
                alloc_mode=None,
                comm_backend=None,
                model_version=global_step + 1,
            )
            if dist.get_rank() == 0:
                future = rollout.update_weights(meta)
            actor.upload_weights(meta)
            if dist.get_rank() == 0:
                future.result()
                shutil.rmtree(path, ignore_errors=True)
            dist.barrier()
            torch.cuda.synchronize()
            rollout.set_version(global_step + 1)

        with stats_tracker.record_timing("save"):
            saver.save(actor, epoch, step, global_step)

        logger.commit(epoch, step, global_step, stats)

    logger.close()
    rollout.destroy()
    if ref is not None:
        ref.destroy()
    actor.destroy()


if __name__ == "__main__":
    main(sys.argv[1:])
@@ -0,0 +1,141 @@
experiment_name: lite-boba-math
trial_name: run1

cluster:
  n_nodes: 16
  n_gpus_per_node: 8
  cluster_name: na132
  fileroot: /storage/openpsi/experiments
  name_resolve:
    type: nfs
    nfs_record_root: /storage/openpsi/experiments/name_resolve/lite-boba-math
    etcd3_addr: etcd-client.openpsi-etcd.svc.sigma-na130-lingbo.na130.wl-robby.local:2379

seed: 1
total_train_epochs: 10
total_train_steps: null
tokenizer_path: ${actor.path}
allocation_mode: sglang.d96p1t1+d32p1t1
async_training: true

rollout:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  max_concurrent_rollouts: 400
  queue_size: null
  consumer_batch_size: ${train_dataset.batch_size}
  max_head_offpolicyness: 4
  enable_rollout_tracing: true

gconfig:
  n_samples: 16
  min_new_tokens: 0
  max_new_tokens: 30720
  greedy: false
  temperature: 1.0

actor:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  path: /storage/openpsi/models/deepseek-ai__DeepSeek-R1-Distill-Qwen-1.5B/
  init_from_scratch: false
  disable_dropout: true
  gradient_checkpointing: true
  dtype: bfloat16
  mb_spec:
    max_tokens_per_mb: 32768
  optimizer:
    type: adam
    lr: 1e-5
    weight_decay: 0.01
    beta1: 0.9
    beta2: 0.999
    eps: 1e-8
    lr_scheduler_type: constant
    gradient_clipping: 1.0
    warmup_steps_proportion: 0.001
  backend: fsdp

  group_size: ${gconfig.n_samples}
  group_adv_norm: false
  eps_clip: 0.4
  temperature: ${gconfig.temperature}
  reward_scaling: 10.0
  reward_bias: -0.5
  kl_ctl: 0.0
  ppo_n_minibatches: 4
  recompute_logprob: true
  use_decoupled_loss: true
  behav_imp_weight_cap: 5.0

ref:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  path: ${actor.path}
  init_from_scratch: false
  disable_dropout: true
  dtype: ${actor.dtype}
  mb_spec:
    max_tokens_per_mb: 32768
  optimizer: null
  backend: fsdp

# SGLang
server_only: false
sglang:
  model_path: ${actor.path}
  random_seed: ${seed}
  skip_tokenizer_init: true
  dtype: ${actor.dtype}
  max_running_requests: null
  context_length: 32768
  mem_fraction_static: 0.9

# datasets
train_dataset:
  batch_size: 512
  shuffle: true
  pin_memory: true

# Utilities
saver:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  freq_epochs: 1
  freq_steps: null
  freq_secs: null

checkpointer:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  freq_epochs: 1
  freq_steps: null
  freq_secs: 3600

evaluator:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  freq_epochs: null
  freq_steps: null
  freq_secs: null

stats_logger:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  wandb:
    mode: online

# Launcher
launcher:
  inference_server_cpus_per_gpu: 15
  inference_server_mem_per_gpu: 153600
  trainer_cpus_per_gpu: 15
  trainer_mem_per_gpu: 153600
  slurm:
    mount: /storage:/storage
    trainer_image: /storage/openpsi/images/arealite-20250712-update-hf-xet.sif
    inference_server_image: /storage/openpsi/images/arealite-20250712-update-hf-xet.sif
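The allocation mode above can be sanity-checked against the cluster size in the spirit of validate_config_for_distributed_launcher: `sglang.d96p1t1+d32p1t1` decouples 96 SGLang generation GPUs (data parallel 96, pipeline 1, tensor 1) from 32 training GPUs, and 96 + 32 = 128 = n_nodes * n_gpus_per_node = 16 * 8. A simplified standalone sketch of that check (the real parser is AllocationMode.from_str; this is not its implementation):

import re

def total_gpus(allocation_mode: str) -> int:
    # d{dp}p{pp}t{tp} -> dp * pp * tp GPUs per parallelization group.
    total = 0
    for part in allocation_mode.split("+"):
        dp, pp, tp = map(int, re.search(r"d(\d+)p(\d+)t(\d+)", part).groups())
        total += dp * pp * tp
    return total

assert total_gpus("sglang.d96p1t1+d32p1t1") == 16 * 8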
@@ -41,7 +41,7 @@ actor:
     max_tokens_per_mb: 10240
   optimizer:
     type: adam
-    lr: 2e-6
+    lr: 1e-5
     weight_decay: 0.01
     beta1: 0.9
     beta2: 0.999
@@ -1,5 +1,5 @@
 import os
-import re
+import shutil
 import sys

 import torch
@@ -18,7 +18,9 @@ from arealite.utils.saver import Saver
 from arealite.utils.stats_logger import StatsLogger
 from arealite.workflow.rlvr import RLVRWorkflow
 from realhf.api.core.data_api import load_hf_tokenizer
-from realhf.base import stats_tracker
+from realhf.base import logging, seeding, stats_tracker
+
+logger = logging.getLogger("GSM8K grpo")


 def process_gsm8k_rl_dataset(dataset: Dataset):
@@ -36,54 +38,22 @@ def get_gsm8k_dataset(split, rank, world_size):
     return process_gsm8k_rl_dataset(dataset)


-# Adapted from verl.
-def extract_solution(solution_str, method="strict") -> str | None:
-    assert method in ["strict", "flexible"]
-
-    final_answer = None
-    if method == "strict":
-        # this also tests the formatting of the model
-        solutions = re.findall("#### (\\-?[0-9\\.\\,]+)", solution_str)
-        if len(solutions) == 0:
-            final_answer = None
-        else:
-            # take the last solution
-            final_answer = solutions[-1].replace(",", "").replace("$", "")
-    elif method == "flexible":
-        answer = re.findall("(\\-?[0-9\\.\\,]+)", solution_str)
-        final_answer = None
-        if len(answer) == 0:
-            # no reward is there is no answer
-            pass
-        else:
-            invalid_str = ["", "."]
-            # find the last number that is not '.'
-            for final_answer in reversed(answer):
-                if final_answer not in invalid_str:
-                    break
-    return final_answer
-
-
 def gsm8k_reward_fn(prompt, completions, prompt_ids, completion_ids, answer, **kwargs):
-    from realhf.impl.dataset.math_parser import extract_answer
+    from realhf.impl.dataset.math_parser import process_results

-    sol = extract_answer(completions, data_name="math")
-    ans = extract_solution(solution_str=answer, method="strict")
-    if sol is None:
-        return 0
-    if ans is None:
-        return 0
-    return int(sol.strip() == ans.strip())
+    return int(process_results(completions, answer)[0])


-def main_grpo():
-    config, _ = load_expr_config(sys.argv[1:], GRPOConfig)
+def main(args):
+    config, _ = load_expr_config(args, GRPOConfig)
     config: GRPOConfig

     rank = int(os.getenv("RANK"))
     world_size = int(os.getenv("WORLD_SIZE"))
     tokenizer = load_hf_tokenizer(config.tokenizer_path)

+    seeding.set_random_seed(config.seed, key=f"trainer{rank}")
+
     # Create dataset and dataloaders
     train_dataloader = StatefulDataLoader(
         get_gsm8k_dataset("train", rank, world_size),
@@ -133,6 +103,9 @@ def main_grpo():
         gconfig=config.gconfig,
         tokenizer=tokenizer,
         enable_thinking=False,
+        dump_dir=os.path.join(
+            StatsLogger.get_log_path(config.stats_logger), "generated"
+        ),
     )

     # Run training.
@@ -190,13 +163,14 @@ def main_grpo():
             log_gpu_stats("ppo update")

         with stats_tracker.record_timing("update_weights"):
-            meta = WeightUpdateMeta(
-                type="disk",
             path = os.path.join(
                 Saver.get_save_checkpoint_root(config.saver),
                 "update_weights",
-                str(global_step),
-            ),
+                str(global_step + 1),
+            )
+            meta = WeightUpdateMeta(
+                type="disk",
+                path=path,
                 alloc_mode=None,
                 comm_backend=None,
                 model_version=global_step + 1,
@@ -206,6 +180,7 @@ def main_grpo():
             actor.upload_weights(meta)
             if dist.get_rank() == 0:
                 future.result()
+                shutil.rmtree(path, ignore_errors=True)
             dist.barrier()
             torch.cuda.synchronize()
             rollout.set_version(global_step + 1)
@@ -253,4 +228,4 @@ def main_grpo():


 if __name__ == "__main__":
-    main_grpo()
+    main(sys.argv[1:])

@@ -35,8 +35,8 @@ def get_gsm8k_dataset(split, tokenizer, rank, world_size):
     return process_gsm8k_sft_dataset(dataset, tokenizer)


-def main_sft():
-    config, _ = load_expr_config(sys.argv[1:], SFTConfig)
+def main(args):
+    config, _ = load_expr_config(args, SFTConfig)
     config: SFTConfig

     rank = int(os.getenv("RANK"))
@@ -121,4 +121,4 @@ def main_sft():


 if __name__ == "__main__":
-    main_sft()
+    main(sys.argv[1:])