mirror of https://github.com/inclusionAI/AReaL
[Feature] [lite] Merge from internal dev repo (#189)
* PullRequest: 353 [Lite] Add gradient checkpointing to FSDPEngine
  Merge branch mzy/add-gradient-ckpt of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/353
  Reviewed-by: 博惟 <bowei.fw@antgroup.com>
  - add gradient checkpointing

* PullRequest: 354 [lite] GRPO pre-commit: minor changes in the FSDP engine
  Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/354
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>

* PullRequest: 355 [Lite] GRPO pre-commit 2: Refactor the RemoteSGLangEngine thread and SGLang configuration
  Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/355?tab=commit
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>

* PullRequest: 357 [lite] GRPO pre-commit 3: Fix typos and experiment utilities
  Merge branch fw/lite-fix2 of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/357?tab=comment
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
  - fix destroy_process_group

* PullRequest: 358 [lite] Support GRPO training locally with the GSM8k dataset
  Merge branch fw/lite-fix3 of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/358
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
  - fix the loss mask

* PullRequest: 368 [lite] Refactor the train engine after merging contributions from GitHub
  Merge branch fw/lite-train-engine of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/368
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>

* PullRequest: 371 [lite] [fix] Fix misc bugs in the GRPO implementation
  Merge branch fw/lite-fix0716 of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/371
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>

* PullRequest: 370 [lite] Add Slurm Launcher and Ray Launcher
  Merge branch mzy/lite/launcher of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/370
  Reviewed-by: 博惟 <bowei.fw@antgroup.com>

* PullRequest: 392 [lite] Fix several bugs regarding RL learning and add an example to reproduce boba-math results
  Merge branch fw/lite-boba of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/392
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
  - support the FSDP engine and the remote SGLang engine; refactor the trainer and train engine
  - fix argument parsing and weight-update errors; Qwen2 GRPO and async training work
  - add an SGLang server wrapper and Slurm launch scripts for the 32k boba-math reproduction run

---------

Co-authored-by: 晓雷 <meizhiyu.mzy@antgroup.com>
This commit is contained in:
parent f68a4f677d
commit 18f8a056b6
@@ -676,20 +676,6 @@ class ClusterSpecConfig:
             "help": "Root for logs and checkpoints. Should be available to all nodes."
         },
     )
-    gpu_type: str = field(
-        default="tesla", metadata={"help": "GPU type of the cluster. Used by slurm."}
-    )
-    mount: str = field(
-        default="/storage:/storage", metadata={"help": "Mount path for slurm."}
-    )
-    gpu_image: str = field(default="", metadata={"help": "slurm image for trainers."})
-    cpu_image: str = field(default="", metadata={"help": "slurm image for CPU jobs."})
-    gpu_infer_image: str = field(
-        default="", metadata={"help": "slurm image for LLM inference."}
-    )
-    node_name_prefix: str = field(
-        default="slurmd-", metadata={"help": "Node prefix for a slurm cluster."}
-    )
     n_nodes: int = field(
         default=32,
         metadata={
@@ -725,6 +711,72 @@ class DatasetConfig:
     drop_last: bool = field(default=True)
 
 
+@dataclass
+class SlurmLauncherConfig:
+    """Configuration for launching the SGLang server with Slurm."""
+
+    srun_additional_args: str = field(
+        default="--overlap --mpi=pmi2 -K --chdir $PWD",
+        metadata={"help": "Additional arguments to pass to the srun command."},
+    )
+    container_type: str = field(
+        default="apptainer",
+        metadata={
+            "help": "Type of containers used in slurm",
+            "choices": ["apptainer", "none"],
+        },
+    )
+    mount: str = field(
+        default="/storage:/storage", metadata={"help": "Mount path for slurm."}
+    )
+    trainer_image: str = field(
+        default="", metadata={"help": "slurm image for trainers."}
+    )
+    inference_server_image: str = field(
+        default="", metadata={"help": "slurm image for LLM inference."}
+    )
+
+
+@dataclass
+class LauncherConfig:
+    """Configuration for launching the SGLang server."""
+
+    inference_server_cpus_per_gpu: int = field(
+        default=4,
+        metadata={"help": "Number of CPUs allocated per GPU for the inference server."},
+    )
+    inference_server_mem_per_gpu: int = field(
+        default=32 * 1024,
+        metadata={"help": "Memory allocated per GPU for the inference server, in MB."},
+    )
+    trainer_cpus_per_gpu: int = field(
+        default=4,
+        metadata={"help": "Number of CPUs allocated per GPU for training."},
+    )
+    trainer_mem_per_gpu: int = field(
+        default=32 * 1024,
+        metadata={"help": "Memory allocated per GPU for training, in MB."},
+    )
+    inference_server_env_vars: str = field(
+        default="",
+        metadata={
+            "help": "Environment variables for the inference server, separated by commas. "
+            "Example: 'ENV1=val1,ENV2=val2'."
+        },
+    )
+    trainer_env_vars: str = field(
+        default="",
+        metadata={
+            "help": "Environment variables for training, separated by commas. "
+            "Example: 'ENV1=val1,ENV2=val2'."
+        },
+    )
+    slurm: SlurmLauncherConfig = field(
+        default_factory=SlurmLauncherConfig,
+        metadata={"help": "Slurm launcher configuration."},
+    )
+
+
 @dataclass
 class BaseExperimentConfig:
     # NOTE: we need this unified config class because different experiments
@@ -742,12 +794,6 @@ class BaseExperimentConfig:
         default_factory=ClusterSpecConfig,
         metadata={"help": "Cluster specification. Mainly used by slurm."},
     )
-    n_nodes: int = field(
-        default=1, metadata={"help": "Number of nodes for experiment."}
-    )
-    n_gpus_per_node: int = field(
-        default=8, metadata={"help": "Number of GPUs per node for this experiment."}
-    )
     allocation_mode: str = field(
         default="",
         metadata={
@@ -785,6 +831,7 @@ class BaseExperimentConfig:
 
     server_only: bool = False
     sglang: SGLangConfig = field(default_factory=SGLangConfig)
+    launcher: LauncherConfig = field(default_factory=LauncherConfig)
 
 
 @dataclass
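The *_env_vars options above take a comma-separated KEY=value list. A minimal sketch of parsing that format (parse_env_var_string is a hypothetical helper for illustration; the repo's own handling lives in arealite.utils.launcher.get_env_vars and is not reproduced here):

def parse_env_var_string(s: str) -> dict:
    # Parse "ENV1=val1,ENV2=val2" into {"ENV1": "val1", "ENV2": "val2"}.
    if not s:
        return {}
    return dict(pair.split("=", 1) for pair in s.split(","))

assert parse_env_var_string("ENV1=val1,ENV2=val2") == {"ENV1": "val1", "ENV2": "val2"}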
@@ -266,6 +266,7 @@ class BaseHFEngine(TrainEngine):
 
             # Scale loss for accumulation
             # Revert gradient averaging across dp ranks
+            # FIXME: should be DP size
             loss_scale *= self.world_size
 
             loss *= loss_scale
@@ -286,8 +287,6 @@ class BaseHFEngine(TrainEngine):
         update_successful = True
 
         current_lr = self.lr_scheduler.get_last_lr()[0]
-        # Optimizer step
-        self.optimizer.step()
         return dict(
             update_successful=float(update_successful),
             grad_norm=float(grad_norm) if grad_norm is not None else float("nan"),
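The FIXME added above notes that the scale should use the data-parallel group size rather than the global world size; the two only differ once model parallelism is enabled. A hedged sketch of the distinction (tp_size and pp_size are illustrative parameters, not fields of this engine):

def dp_group_size(world_size: int, tp_size: int = 1, pp_size: int = 1) -> int:
    # With tensor/pipeline parallelism only world_size // (tp_size * pp_size) ranks
    # hold independent data shards, so gradient averaging happens over that many
    # ranks rather than over world_size.
    return world_size // (tp_size * pp_size)

assert dp_group_size(world_size=8) == 8             # pure data parallelism: identical
assert dp_group_size(world_size=8, tp_size=2) == 4  # with TP the two values diverge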
@ -1,10 +1,7 @@
|
||||||
import dis
|
|
||||||
import gc
|
|
||||||
import os
|
import os
|
||||||
import threading
|
|
||||||
import time
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
from typing import Callable, Dict, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
@ -14,14 +11,7 @@ from torch.distributed.checkpoint.state_dict import (
|
||||||
StateDictOptions,
|
StateDictOptions,
|
||||||
get_model_state_dict,
|
get_model_state_dict,
|
||||||
)
|
)
|
||||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
from transformers import PreTrainedTokenizerFast
|
||||||
from transformers import (
|
|
||||||
AutoConfig,
|
|
||||||
AutoModelForCausalLM,
|
|
||||||
PreTrainedTokenizerFast,
|
|
||||||
get_constant_schedule_with_warmup,
|
|
||||||
get_linear_schedule_with_warmup,
|
|
||||||
)
|
|
||||||
|
|
||||||
from arealite.api.cli_args import TrainEngineConfig
|
from arealite.api.cli_args import TrainEngineConfig
|
||||||
from arealite.api.engine_api import FinetuneSpec, SaveLoadMeta, WeightUpdateMeta
|
from arealite.api.engine_api import FinetuneSpec, SaveLoadMeta, WeightUpdateMeta
|
||||||
|
@ -232,6 +222,7 @@ class FSDPEngine(BaseHFEngine):
|
||||||
for i, (pad_length, padded_mb_input, mb_input) in enumerate(
|
for i, (pad_length, padded_mb_input, mb_input) in enumerate(
|
||||||
zip(mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs)
|
zip(mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs)
|
||||||
):
|
):
|
||||||
|
self.model.set_requires_gradient_sync(i == len(mb_list.mbs) - 1)
|
||||||
outputs = self.model(**padded_mb_input)
|
outputs = self.model(**padded_mb_input)
|
||||||
|
|
||||||
logits = outputs.logits.squeeze(0)
|
logits = outputs.logits.squeeze(0)
|
||||||
|
@ -258,8 +249,6 @@ class FSDPEngine(BaseHFEngine):
|
||||||
update_successful = True
|
update_successful = True
|
||||||
|
|
||||||
current_lr = self.lr_scheduler.get_last_lr()[0]
|
current_lr = self.lr_scheduler.get_last_lr()[0]
|
||||||
# Optimizer step
|
|
||||||
self.optimizer.step()
|
|
||||||
return dict(
|
return dict(
|
||||||
update_successful=float(update_successful),
|
update_successful=float(update_successful),
|
||||||
grad_norm=float(grad_norm) if grad_norm is not None else float("nan"),
|
grad_norm=float(grad_norm) if grad_norm is not None else float("nan"),
|
||||||
|
|
|
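The new set_requires_gradient_sync call above skips the cross-rank gradient reduction for every micro-batch except the last one. A minimal sketch of the same accumulation pattern written with the no_sync escape hatch exposed by DistributedDataParallel (and FSDP1); illustrative only, not the engine's code:

import contextlib

def accumulate_grads(model, microbatches, loss_fn):
    # Defer the cross-rank gradient all-reduce to the final micro-batch, mirroring
    # set_requires_gradient_sync(i == len(mbs) - 1) above.
    for i, mb in enumerate(microbatches):
        is_last = i == len(microbatches) - 1
        ctx = contextlib.nullcontext() if is_last else model.no_sync()
        with ctx:
            loss = loss_fn(model(mb))
            loss.backward()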
@@ -58,15 +58,9 @@ def compute_packed_sft_loss(logits: torch.Tensor, input_: TensorDict) -> torch.T
         cu_seqlens.shape[0] - 1, device=logits.device, dtype=torch.float64
     )
     for i in range(cu_seqlens.shape[0] - 1):
-        m = loss_mask[cu_seqlens[i] - i : cu_seqlens[i + 1] - i - 1]
-        logp = logprobs[cu_seqlens[i] - i : cu_seqlens[i + 1] - i - 1]
-        assert cu_seqlens[i + 1] - i - 1 <= logprobs.shape[0], (
-            cu_seqlens,
-            logprobs.shape,
-        )
-        seqlogp[i] = torch.where(m, logp.detach(), 0.0).sum() / (
-            m.numel() - m.count_nonzero()
-        )
+        m = loss_mask[cu_seqlens[i] : cu_seqlens[i + 1]]
+        logp = logprobs[cu_seqlens[i] : cu_seqlens[i + 1]]
+        seqlogp[i] = torch.where(m, logp.detach(), 0.0).sum() / (m.count_nonzero())
 
     ## Logging stats
     stats_tracker.denominator(
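The corrected loss above slices the packed log-probabilities with cu_seqlens directly. A small self-contained example of that indexing (made-up sequence lengths, not repo code):

import torch

# Three sequences of lengths 3, 2 and 4 packed into one 1-D tensor.
logprobs = torch.randn(9)
loss_mask = torch.ones(9, dtype=torch.bool)
cu_seqlens = torch.tensor([0, 3, 5, 9])

seqlogp = torch.zeros(cu_seqlens.shape[0] - 1, dtype=torch.float64)
for i in range(cu_seqlens.shape[0] - 1):
    m = loss_mask[cu_seqlens[i] : cu_seqlens[i + 1]]
    logp = logprobs[cu_seqlens[i] : cu_seqlens[i + 1]]
    # Mean log-probability over the unmasked tokens of sequence i.
    seqlogp[i] = torch.where(m, logp, torch.zeros_like(logp)).sum() / m.count_nonzero()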
@@ -0,0 +1,410 @@
|
||||||
|
import getpass
|
||||||
|
import importlib.util
|
||||||
|
import os
|
||||||
|
import pathlib
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
import ray
|
||||||
|
import ray.exceptions
|
||||||
|
from ray.runtime_env import RuntimeEnv
|
||||||
|
from ray.util.placement_group import PlacementGroup
|
||||||
|
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||||
|
|
||||||
|
import realhf.base.logging as logging
|
||||||
|
from arealite.api.cli_args import (
|
||||||
|
ClusterSpecConfig,
|
||||||
|
LauncherConfig,
|
||||||
|
SGLangConfig,
|
||||||
|
parse_cli_args,
|
||||||
|
to_structured_cfg,
|
||||||
|
)
|
||||||
|
from arealite.api.io_struct import AllocationMode, AllocationType
|
||||||
|
from arealite.utils.launcher import (
|
||||||
|
get_env_vars,
|
||||||
|
validate_config_for_distributed_launcher,
|
||||||
|
wait_sglang_server_addrs,
|
||||||
|
)
|
||||||
|
from arealite.utils.ray import get_placement_group_master_ip_and_port
|
||||||
|
from realhf.base import logging, name_resolve, names
|
||||||
|
from realhf.scheduler.client import JobException, JobState
|
||||||
|
|
||||||
|
logger = logging.getLogger("RayLauncher")
|
||||||
|
|
||||||
|
RAY_WAIT_CHECK_TIME_INTERVAL = 5 # seconds
|
||||||
|
DEFAULT_MAIN_FUNC_NAME = "main"
|
||||||
|
|
||||||
|
|
||||||
|
def run_func(file_path, function_name, *args, **kwargs):
|
||||||
|
# Convert the file path to a module name
|
||||||
|
module_name = file_path.replace("/", "_").replace(".", "_")
|
||||||
|
|
||||||
|
# Load the module from file path
|
||||||
|
spec = importlib.util.spec_from_file_location(module_name, file_path)
|
||||||
|
module = importlib.util.module_from_spec(spec)
|
||||||
|
sys.modules[module_name] = module
|
||||||
|
spec.loader.exec_module(module)
|
||||||
|
|
||||||
|
# Get the function and execute it
|
||||||
|
try:
|
||||||
|
function = getattr(module, function_name)
|
||||||
|
except AttributeError as e:
|
||||||
|
raise ValueError(
|
||||||
|
f"Function '{function_name}' not found in module '{module_name}'. "
|
||||||
|
f"Please ensure the name of the main function in your entry point "
|
||||||
|
f"is '{function_name}'."
|
||||||
|
) from e
|
||||||
|
return function(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class RayLauncher:
|
||||||
|
def __init__(self, experiment_name: str, trial_name: str, fileroot: str):
|
||||||
|
self.experiment_name = experiment_name
|
||||||
|
self.trial_name = trial_name
|
||||||
|
self.fileroot = fileroot
|
||||||
|
|
||||||
|
# job_name to ray future
|
||||||
|
self.jobs = {}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def run_name(self):
|
||||||
|
return f"{self.experiment_name}_{self.trial_name}"
|
||||||
|
|
||||||
|
def log_path_of(self, job_name: str) -> str:
|
||||||
|
log_path = f"{self.fileroot}/logs/{getpass.getuser()}/{self.experiment_name}/{self.trial_name}"
|
||||||
|
os.makedirs(log_path, exist_ok=True)
|
||||||
|
return os.path.join(log_path, f"{job_name}.log")
|
||||||
|
|
||||||
|
def submit(
|
||||||
|
self,
|
||||||
|
job_name: str,
|
||||||
|
file_path: str,
|
||||||
|
func_name: str,
|
||||||
|
args: List[str], # arguments to pass to the function
|
||||||
|
gpus: int,
|
||||||
|
cpus: int,
|
||||||
|
mem: int, # MB
|
||||||
|
env_vars: Optional[Dict] = None,
|
||||||
|
placement_group: Optional[PlacementGroup] = None,
|
||||||
|
bundle_index: int = -1,
|
||||||
|
kwargs: Optional[
|
||||||
|
Dict[str, str]
|
||||||
|
] = None, # keyword arguments to pass to the function
|
||||||
|
):
|
||||||
|
if kwargs is None:
|
||||||
|
kwargs = {}
|
||||||
|
runtime_env = RuntimeEnv(
|
||||||
|
env_vars=env_vars or dict(),
|
||||||
|
)
|
||||||
|
scheduling_strategy = (
|
||||||
|
PlacementGroupSchedulingStrategy(
|
||||||
|
placement_group=placement_group,
|
||||||
|
placement_group_bundle_index=bundle_index,
|
||||||
|
placement_group_capture_child_tasks=True,
|
||||||
|
)
|
||||||
|
if placement_group is not None
|
||||||
|
else "DEFAULT"
|
||||||
|
)
|
||||||
|
future = ray.remote(
|
||||||
|
num_cpus=cpus,
|
||||||
|
num_gpus=gpus,
|
||||||
|
memory=mem * 1024 * 1024, # Convert MB to bytes
|
||||||
|
runtime_env=runtime_env,
|
||||||
|
scheduling_strategy=scheduling_strategy,
|
||||||
|
)(run_func).remote(file_path, func_name, *args, **kwargs)
|
||||||
|
self.jobs[job_name] = future
|
||||||
|
return future
|
||||||
|
|
||||||
|
def submit_array(
|
||||||
|
self,
|
||||||
|
job_name: str,
|
||||||
|
file_path: str,
|
||||||
|
func_name: str,
|
||||||
|
count: int,
|
||||||
|
nodes: int,
|
||||||
|
list_args: List[List],
|
||||||
|
gpus_per_task: int,
|
||||||
|
cpus_per_task: int,
|
||||||
|
mem_per_task: int, # MB
|
||||||
|
list_kwargs: List[Dict] | None = None,
|
||||||
|
env_vars: Optional[Dict] = None,
|
||||||
|
amend_torch_dist_env: bool = False,
|
||||||
|
):
|
||||||
|
"""Submit an array of jobs to Ray with ray placement groups.
|
||||||
|
|
||||||
|
Note: Here we use `ray.remote` instead of `ray job submit` since `ray job submit`
|
||||||
|
does not support placement groups, and can not specify which node to run the job on.
|
||||||
|
Therefore we could not know the IP address of jobs for torch distributed initialization.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if count % nodes != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Count {count} is not divisible by nodes {nodes}. "
|
||||||
|
"Please ensure that count is a multiple of nodes."
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
len(list_args) == count
|
||||||
|
), f"Length of list_args {len(list_args)} does not match count {count}."
|
||||||
|
if list_kwargs is not None:
|
||||||
|
assert (
|
||||||
|
len(list_kwargs) == count
|
||||||
|
), f"Length of list_kwargs {len(list_kwargs)} does not match count {count}."
|
||||||
|
|
||||||
|
tasks_per_node = count // nodes
|
||||||
|
gpus_per_node = gpus_per_task * tasks_per_node
|
||||||
|
cpus_per_node = cpus_per_task * tasks_per_node
|
||||||
|
mem_per_node = mem_per_task * tasks_per_node
|
||||||
|
|
||||||
|
placement_group = ray.util.placement_group(
|
||||||
|
bundles=[
|
||||||
|
{
|
||||||
|
"CPU": cpus_per_node,
|
||||||
|
"GPU": gpus_per_node,
|
||||||
|
"memory": mem_per_node * 1024 * 1024, # Convert MB to bytes
|
||||||
|
}
|
||||||
|
]
|
||||||
|
* nodes,
|
||||||
|
strategy="STRICT_SPREAD",
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
ray.get(placement_group.ready(), timeout=30)
|
||||||
|
except ray.exceptions.GetTimeoutError as e:
|
||||||
|
logger.error(
|
||||||
|
"Ray placement group timeout, please check if the resource requirement "
|
||||||
|
"for your experiment exceeds the available resources in the cluster. \n"
|
||||||
|
f"ray.nodes(): {ray.nodes()} \n"
|
||||||
|
f"Placement Group bundles: "
|
||||||
|
f"cpus_per_node={cpus_per_node}, gpus_per_node={gpus_per_node}, "
|
||||||
|
f"mem_per_node={mem_per_node}MB, nodes={nodes}"
|
||||||
|
)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
if amend_torch_dist_env:
|
||||||
|
host_ip, port = get_placement_group_master_ip_and_port(placement_group)
|
||||||
|
logger.info(
|
||||||
|
f"Amend torch distributed env vars: "
|
||||||
|
f"MASTER_ADDR={host_ip}, PORT={port}"
|
||||||
|
)
|
||||||
|
|
||||||
|
futures = []
|
||||||
|
for i in range(count):
|
||||||
|
args = list_args[i]
|
||||||
|
kwargs = list_kwargs[i] if list_kwargs is not None else {}
|
||||||
|
|
||||||
|
# manage environment variables
|
||||||
|
env_vars = env_vars or {}
|
||||||
|
if "CUDA_VISIBLE_DEVICES" in env_vars:
|
||||||
|
logger.warning(
|
||||||
|
"Setting CUDA_VISIBLE_DEVICES before running ray jobs may result in unexpected behavior."
|
||||||
|
)
|
||||||
|
|
||||||
|
node_id = i // tasks_per_node
|
||||||
|
_env_vars = {
|
||||||
|
**env_vars,
|
||||||
|
}
|
||||||
|
|
||||||
|
if amend_torch_dist_env:
|
||||||
|
assert gpus_per_task == 1
|
||||||
|
# NOTE: Here we only provide environment variables for torch distributed
|
||||||
|
# initialization, and LOCAL_RANK for torch.device.
|
||||||
|
# Other environment variables automatically set by torchrun are not set, and
|
||||||
|
# they should be never accessed in trainer code.
|
||||||
|
_env_vars.update(
|
||||||
|
{
|
||||||
|
"RANK": str(i),
|
||||||
|
"WORLD_SIZE": str(count),
|
||||||
|
# Ray will automatically isolate CUDA_VISIBLE_DEVICES for each GPU
|
||||||
|
"LOCAL_RANK": "0",
|
||||||
|
"MASTER_ADDR": str(host_ip),
|
||||||
|
"MASTER_PORT": str(port),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
future = self.submit(
|
||||||
|
job_name=f"{job_name}:{i}",
|
||||||
|
file_path=file_path,
|
||||||
|
func_name=func_name,
|
||||||
|
args=args,
|
||||||
|
gpus=gpus_per_task,
|
||||||
|
cpus=cpus_per_task,
|
||||||
|
mem=mem_per_task,
|
||||||
|
env_vars=_env_vars,
|
||||||
|
placement_group=placement_group,
|
||||||
|
bundle_index=node_id,
|
||||||
|
kwargs=kwargs,
|
||||||
|
)
|
||||||
|
futures.append(future)
|
||||||
|
|
||||||
|
return futures
|
||||||
|
|
||||||
|
def stop(self, job_name: str, force: bool = False):
|
||||||
|
"""Stop a job by name."""
|
||||||
|
if job_name in self.jobs:
|
||||||
|
future = self.jobs[job_name]
|
||||||
|
try:
|
||||||
|
ray.cancel(future, force=force)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to cancel job {job_name}: {e}")
|
||||||
|
return
|
||||||
|
self.jobs.pop(job_name, None)
|
||||||
|
logger.info(f"Job {job_name} stopped.")
|
||||||
|
else:
|
||||||
|
logger.warning(f"Job {job_name} not found in running jobs.")
|
||||||
|
|
||||||
|
def stop_all(self, force: bool = False):
|
||||||
|
"""Stop all jobs."""
|
||||||
|
for job_name in list(self.jobs.keys()):
|
||||||
|
self.stop(job_name, force=force)
|
||||||
|
logger.info("All jobs stopped.")
|
||||||
|
self.jobs.clear()
|
||||||
|
|
||||||
|
def wait(
|
||||||
|
self, check_status=(JobState.FAILED,), remove_status=(JobState.COMPLETED,)
|
||||||
|
):
|
||||||
|
"""Check every RAY_WAIT_CHECK_TIME_INTERVAL seconds for the status of all jobs.
|
||||||
|
If a ray job returns, its status changes to JobState.COMPLETED.
|
||||||
|
If a ray job failed, its status changes to JobState.FAILED.
|
||||||
|
If any job is in check_status, stop all jobs at once.
|
||||||
|
If any job is in remove status, remove them from job list.
|
||||||
|
Return if all jobs are removed from job list, or some job is in check status.
|
||||||
|
"""
|
||||||
|
for status in list(check_status) + list(remove_status):
|
||||||
|
assert status in [
|
||||||
|
JobState.COMPLETED,
|
||||||
|
JobState.FAILED,
|
||||||
|
], "In RayLauncher.wait, we only check completed or failed jobs."
|
||||||
|
logger.info(f"Waiting for {len(self.jobs)} jobs.")
|
||||||
|
while self.jobs:
|
||||||
|
job_status = {}
|
||||||
|
for job_name, future in list(self.jobs.items()):
|
||||||
|
try:
|
||||||
|
r = ray.get(future, timeout=0.1)
|
||||||
|
logger.info(f"Job {job_name} completed with result: {r}")
|
||||||
|
job_status[job_name] = JobState.COMPLETED
|
||||||
|
except ray.exceptions.RayTaskError as e:
|
||||||
|
logger.error(f"Job {job_name} failed with error: {e}.")
|
||||||
|
job_status[job_name] = JobState.FAILED
|
||||||
|
except ray.exceptions.GetTimeoutError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for job_name, status in job_status.items():
|
||||||
|
if status in check_status:
|
||||||
|
logger.info(f"Job {job_name} is {status}, stopping all jobs.")
|
||||||
|
self.stop_all(force=True)
|
||||||
|
return
|
||||||
|
if status in remove_status:
|
||||||
|
logger.info(f"Job {job_name} is {status}, removed.")
|
||||||
|
self.jobs.pop(job_name)
|
||||||
|
|
||||||
|
time.sleep(RAY_WAIT_CHECK_TIME_INTERVAL)
|
||||||
|
|
||||||
|
|
||||||
|
def ray_main():
|
||||||
|
# usage: python -m arealite.launcher.ray <entry_point> --config <config_path> [<additional_args>]
|
||||||
|
ray.init()
|
||||||
|
config, config_file = parse_cli_args(sys.argv[2:])
|
||||||
|
config.launcher = to_structured_cfg(config.launcher, LauncherConfig)
|
||||||
|
config.cluster = to_structured_cfg(config.cluster, ClusterSpecConfig)
|
||||||
|
config.sglang = to_structured_cfg(config.sglang, SGLangConfig)
|
||||||
|
validate_config_for_distributed_launcher(config)
|
||||||
|
|
||||||
|
name_resolve.reconfigure(config.cluster.name_resolve)
|
||||||
|
name_resolve.clear_subtree(
|
||||||
|
names.trial_root(
|
||||||
|
experiment_name=config.experiment_name, trial_name=config.trial_name
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
n_nodes = config.cluster.n_nodes
|
||||||
|
n_gpus_per_node = config.cluster.n_gpus_per_node
|
||||||
|
launcher = RayLauncher(
|
||||||
|
experiment_name=config.experiment_name,
|
||||||
|
trial_name=config.trial_name,
|
||||||
|
fileroot=config.cluster.fileroot,
|
||||||
|
)
|
||||||
|
allocation_mode = config.allocation_mode
|
||||||
|
allocation_mode = AllocationMode.from_str(allocation_mode)
|
||||||
|
sglang_cmds = []
|
||||||
|
sglang_addrs = []
|
||||||
|
n_sglang_nodes = 0
|
||||||
|
if allocation_mode.type_ == AllocationType.DECOUPLED_SGLANG:
|
||||||
|
# Launcher should launch SGLang servers according to allocation mode.
|
||||||
|
sglang_tp_size = allocation_mode.gen_tp_size
|
||||||
|
n_sglang_servers = allocation_mode.gen_dp_size
|
||||||
|
n_sglang_nodes = allocation_mode.gen_world_size // n_gpus_per_node
|
||||||
|
|
||||||
|
base_seed = config.sglang.random_seed
|
||||||
|
sglang_args_list = [
|
||||||
|
[sys.argv[2:] + [f"sglang.random_seed={base_seed + i}"]]
|
||||||
|
for i in range(n_sglang_servers)
|
||||||
|
]
|
||||||
|
sglang_entry_point = str(
|
||||||
|
pathlib.Path(__file__).resolve().parent.joinpath("sglang_server.py")
|
||||||
|
)
|
||||||
|
launcher.submit_array(
|
||||||
|
job_name="llm_server",
|
||||||
|
file_path=sglang_entry_point,
|
||||||
|
func_name=DEFAULT_MAIN_FUNC_NAME,
|
||||||
|
count=n_sglang_servers,
|
||||||
|
nodes=n_sglang_nodes,
|
||||||
|
list_args=sglang_args_list,
|
||||||
|
gpus_per_task=sglang_tp_size,
|
||||||
|
cpus_per_task=config.launcher.inference_server_cpus_per_gpu
|
||||||
|
* sglang_tp_size,
|
||||||
|
mem_per_task=config.launcher.inference_server_mem_per_gpu * sglang_tp_size,
|
||||||
|
env_vars=get_env_vars(
|
||||||
|
config.cluster.cluster_name,
|
||||||
|
config.launcher.inference_server_env_vars,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
# Get SGLang server addresses via name_resolve
|
||||||
|
try:
|
||||||
|
sglang_addrs = wait_sglang_server_addrs(
|
||||||
|
config.experiment_name,
|
||||||
|
config.trial_name,
|
||||||
|
n_sglang_servers,
|
||||||
|
)
|
||||||
|
except TimeoutError as e:
|
||||||
|
launcher.stop_all(force=True)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
trainer_n_nodes = n_nodes - n_sglang_nodes
|
||||||
|
trainer_entry_point = sys.argv[1]
|
||||||
|
n_trainer_processes = trainer_n_nodes * config.cluster.n_gpus_per_node
|
||||||
|
trainer_args_list = [[sys.argv[2:]] for _ in range(n_trainer_processes)]
|
||||||
|
if not config.server_only:
|
||||||
|
# In ray, we launch trainer in the granularity of processes (1 GPU per process)
|
||||||
|
# We amend environment variable similar to torchrun to ensure correct initialization of
|
||||||
|
# torch distributed.
|
||||||
|
launcher.submit_array(
|
||||||
|
job_name="trainer",
|
||||||
|
file_path=trainer_entry_point,
|
||||||
|
func_name=DEFAULT_MAIN_FUNC_NAME,
|
||||||
|
count=trainer_n_nodes * config.cluster.n_gpus_per_node,
|
||||||
|
nodes=trainer_n_nodes,
|
||||||
|
list_args=trainer_args_list,
|
||||||
|
gpus_per_task=1,
|
||||||
|
cpus_per_task=config.launcher.trainer_cpus_per_gpu,
|
||||||
|
mem_per_task=config.launcher.trainer_mem_per_gpu,
|
||||||
|
env_vars=dict(
|
||||||
|
**get_env_vars(
|
||||||
|
config.cluster.cluster_name,
|
||||||
|
config.launcher.trainer_env_vars,
|
||||||
|
),
|
||||||
|
AREAL_LLM_SERVER_ADDRS=",".join(sglang_addrs),
|
||||||
|
),
|
||||||
|
amend_torch_dist_env=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
launcher.wait(check_status=(JobState.COMPLETED, JobState.FAILED))
|
||||||
|
except (KeyboardInterrupt, JobException, TimeoutError) as e:
|
||||||
|
launcher.stop_all(force=True)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# usage: python -m arealite.launcher.ray \
|
||||||
|
# <entry_point> --config <config_path> [<additional_args>] \
|
||||||
|
# launcher.ray.main_func_name=<main_func_name_in_entry_point>
|
||||||
|
ray_main()
|
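The Ray launcher above injects torchrun-style variables (RANK, WORLD_SIZE, LOCAL_RANK, MASTER_ADDR, MASTER_PORT) into each single-GPU trainer process. A minimal sketch of how a trainer entry point could consume them, assuming the standard env:// initialization; this is illustrative and not taken from the repo:

import os

import torch
import torch.distributed as dist

def init_from_env():
    # env:// reads RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT set by the launcher.
    dist.init_process_group(backend="nccl", init_method="env://")
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", "0")))
    return dist.get_rank(), dist.get_world_size()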
|
@@ -0,0 +1,197 @@
+import os
+import subprocess
+import sys
+import time
+import uuid
+from pathlib import Path
+from typing import Optional
+
+import ray
+import requests
+
+from arealite.api.cli_args import (
+    ClusterSpecConfig,
+    NameResolveConfig,
+    SGLangConfig,
+    parse_cli_args,
+    to_structured_cfg,
+)
+from arealite.api.io_struct import AllocationMode, AllocationType
+from arealite.utils.launcher import TRITON_CACHE_PATH
+from arealite.utils.network import find_free_ports, gethostip
+from realhf.base import logging, name_resolve, names, pkg_version
+
+logger = logging.getLogger("SGLangServer Wrapper")
+
+
+def execute_shell_command(command: str) -> subprocess.Popen:
+    """
+    Execute a shell command and return its process handle.
+    """
+    # Replace newline continuations and split the command string.
+    command = command.replace("\\\n", " ").replace("\\", " ")
+    parts = command.split()
+    _env = os.environ.copy()
+    # To avoid DirectoryNotEmpty error caused by triton
+    triton_cache_path = _env.get("TRITON_CACHE_PATH", TRITON_CACHE_PATH)
+    unique_triton_cache_path = os.path.join(triton_cache_path, str(uuid.uuid4()))
+    _env["TRITON_CACHE_PATH"] = unique_triton_cache_path
+    return subprocess.Popen(
+        parts,
+        text=True,
+        env=_env,
+        stdout=sys.stdout,
+        stderr=subprocess.STDOUT,
+    )
+
+
+def apply_sglang_patch():
+    p = Path(os.path.dirname(__file__))
+    patch_path = str(
+        p.parent.parent
+        / "patch"
+        / "sglang"
+        / f"v{pkg_version.get_version('sglang')}.patch"
+    )
+
+    target_path = ""
+    sglang_meta = subprocess.check_output(
+        "python3 -m pip show sglang", shell=True
+    ).decode("ascii")
+    for line in sglang_meta.split("\n"):
+        line = line.strip()
+        if line.startswith("Editable project location: "):
+            target_path = str(Path(line.split(": ")[1]).parent)
+
+    if target_path:
+        proc = subprocess.Popen(
+            ["git", "apply", patch_path],
+            cwd=target_path,
+            stderr=sys.stdout,
+            stdout=sys.stdout,
+        )
+        proc.wait()
+        logger.info(f"Applied SGLang patch at {target_path}")
+
+
+def launch_server_cmd(command: str):
+    """
+    Launch the server using the given command.
+    If no port is specified, a free port is reserved.
+    """
+    if not ray.is_initialized():
+        apply_sglang_patch()
+    process = execute_shell_command(command)
+    return process
+
+
+def wait_for_server(base_url: str, timeout: Optional[int] = None) -> None:
+    """Wait for the server to be ready by polling the /v1/models endpoint.
+
+    Args:
+        base_url: The base URL of the server
+        timeout: Maximum time to wait in seconds. None means wait forever.
+    """
+    start_time = time.time()
+    while True:
+        try:
+            response = requests.get(
+                f"{base_url}/v1/models",
+                headers={"Authorization": "Bearer None"},
+            )
+            if response.status_code == 200:
+                time.sleep(5)
+                break
+
+            if timeout and time.time() - start_time > timeout:
+                raise TimeoutError("Server did not become ready within timeout period")
+        except requests.exceptions.RequestException:
+            time.sleep(1)
+
+
+class SGLangServerWrapper:
+    def __init__(
+        self,
+        experiment_name: str,
+        trial_name: str,
+        sglang_config: SGLangConfig,
+        tp_size: int,
+        n_gpus_per_node: int,
+    ):
+        self.experiment_name = experiment_name
+        self.trial_name = trial_name
+        self.config = sglang_config
+        self.tp_size = tp_size
+        self.server_process = None
+        self.n_gpus_per_node = n_gpus_per_node
+
+    def run(self):
+        gpus_per_server = len(os.getenv("CUDA_VISIBLE_DEVICES").split(","))
+        server_local_idx = (
+            int(os.getenv("CUDA_VISIBLE_DEVICES").split(",")[0]) // gpus_per_server
+        )
+        n_servers_per_node = max(1, self.n_gpus_per_node // gpus_per_server)
+        ports_per_server = 40000 // n_servers_per_node
+        port_range = (
+            server_local_idx * ports_per_server + 10000,
+            (server_local_idx + 1) * ports_per_server + 10000,
+        )
+        server_port, dist_init_port = find_free_ports(2, port_range)
+
+        dist_init_addr = f"localhost:{dist_init_port}"
+        host_ip = gethostip()
+
+        cmd = SGLangConfig.build_cmd(
+            self.config,
+            tp_size=self.tp_size,
+            base_gpu_id=0,
+            host=host_ip,
+            port=server_port,
+            dist_init_addr=dist_init_addr,
+        )
+        self.server_process = launch_server_cmd(cmd)
+        wait_for_server(f"http://{host_ip}:{server_port}")
+
+        name = names.gen_servers(self.experiment_name, self.trial_name)
+        name_resolve.add_subentry(name, f"{host_ip}:{server_port}")
+
+        logger.info(f"SGLang server launched at: http://{host_ip}:{server_port}")
+        return_code = self.server_process.wait()
+        logger.info(
+            f"SGLang server at http://{host_ip}:{server_port} exits, returncode={return_code}"
+        )
+
+    def __del__(self):
+        if self.server_process and self.server_process.poll() is None:
+            logger.info("Terminating SGLang server process...")
+            self.server_process.terminate()
+            self.server_process.wait()
+            logger.info("SGLang server process terminated.")
+
+
+def main(argv):
+    config, _ = parse_cli_args(argv)
+    config.sglang = to_structured_cfg(config.sglang, SGLangConfig)
+    config.cluster = to_structured_cfg(config.cluster, ClusterSpecConfig)
+    config.cluster.name_resolve = to_structured_cfg(
+        config.cluster.name_resolve, NameResolveConfig
+    )
+    name_resolve.reconfigure(config.cluster.name_resolve)
+
+    allocation_mode = config.allocation_mode
+    allocation_mode = AllocationMode.from_str(allocation_mode)
+    assert allocation_mode.type_ == AllocationType.DECOUPLED_SGLANG
+    tp_size = allocation_mode.gen_tp_size
+
+    sglang_server = SGLangServerWrapper(
+        config.experiment_name,
+        config.trial_name,
+        config.sglang,
+        tp_size,
+        n_gpus_per_node=config.cluster.n_gpus_per_node,
+    )
+    sglang_server.run()
+
+
+if __name__ == "__main__":
+    main(sys.argv[1:])
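A worked example of the port partitioning performed by SGLangServerWrapper.run above, with assumed numbers (8 GPUs per node and tp_size=2 servers, i.e., 2 GPUs per server):

n_gpus_per_node, gpus_per_server = 8, 2
n_servers_per_node = max(1, n_gpus_per_node // gpus_per_server)   # 4
ports_per_server = 40000 // n_servers_per_node                    # 10000
for server_local_idx in range(n_servers_per_node):
    port_range = (
        server_local_idx * ports_per_server + 10000,
        (server_local_idx + 1) * ports_per_server + 10000,
    )
    # Disjoint windows: (10000, 20000), (20000, 30000), (30000, 40000), (40000, 50000)
    print(server_local_idx, port_range)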
@@ -0,0 +1,528 @@
|
||||||
|
import getpass
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import realhf.base.logging as logging
|
||||||
|
from arealite.api.cli_args import (
|
||||||
|
ClusterSpecConfig,
|
||||||
|
LauncherConfig,
|
||||||
|
SGLangConfig,
|
||||||
|
parse_cli_args,
|
||||||
|
to_structured_cfg,
|
||||||
|
)
|
||||||
|
from arealite.api.io_struct import AllocationMode, AllocationType
|
||||||
|
from arealite.utils.launcher import (
|
||||||
|
get_env_vars,
|
||||||
|
validate_config_for_distributed_launcher,
|
||||||
|
wait_sglang_server_addrs,
|
||||||
|
)
|
||||||
|
from arealite.utils.slurm import (
|
||||||
|
APPTAINER_CMD_TEMPLATE,
|
||||||
|
SBATCH_SCRIPT_TEMPLATE,
|
||||||
|
SRUN_CMD_TEMPLATE,
|
||||||
|
cancel_jobs,
|
||||||
|
query_jobs,
|
||||||
|
)
|
||||||
|
from realhf.base import logging, name_resolve, names
|
||||||
|
from realhf.scheduler.client import JobException, JobInfo, JobState
|
||||||
|
|
||||||
|
logger = logging.getLogger("SlurmLauncher")
|
||||||
|
|
||||||
|
SLURM_WAIT_CHECK_TIME_INTERVAL = 5
|
||||||
|
|
||||||
|
|
||||||
|
class SlurmLauncher:
|
||||||
|
def __init__(
|
||||||
|
self, experiment_name: str, trial_name: str, fileroot: str, container_type: str
|
||||||
|
):
|
||||||
|
self.experiment_name = experiment_name
|
||||||
|
self.trial_name = trial_name
|
||||||
|
self.fileroot = fileroot
|
||||||
|
self.container_type = container_type
|
||||||
|
|
||||||
|
# slurm_job_id -> JobInfo
|
||||||
|
self.jobs: Dict[int, JobInfo] = {}
|
||||||
|
self.job_names = []
|
||||||
|
|
||||||
|
@property
|
||||||
|
def run_name(self) -> str:
|
||||||
|
"""Returns the run name of this launcher."""
|
||||||
|
return f"{self.experiment_name}_{self.trial_name}"
|
||||||
|
|
||||||
|
def slurm_name(self, job_name: str) -> str:
|
||||||
|
"""Returns the slurm name of a job."""
|
||||||
|
return f"{self.experiment_name}_{self.trial_name}:{job_name}"
|
||||||
|
|
||||||
|
def log_path_of(self, job_name: str) -> str:
|
||||||
|
log_path = f"{self.fileroot}/logs/{getpass.getuser()}/{self.experiment_name}/{self.trial_name}"
|
||||||
|
os.makedirs(log_path, exist_ok=True)
|
||||||
|
return os.path.join(log_path, f"{job_name}.log")
|
||||||
|
|
||||||
|
def sbatch_path_of(self, job_name: str) -> str:
|
||||||
|
sbatch_path = f"{self.fileroot}/logs/{getpass.getuser()}/{self.experiment_name}/{self.trial_name}"
|
||||||
|
os.makedirs(sbatch_path, exist_ok=True)
|
||||||
|
return os.path.join(sbatch_path, f"{job_name}.sh")
|
||||||
|
|
||||||
|
def submit(self, job_name, cmd, **kwargs):
|
||||||
|
"""Submits and launch a job with SBATCH.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cmd (str or List[str]): The core command to be executed.
|
||||||
|
"""
|
||||||
|
return self.submit_array(job_name, cmd, count=1, **kwargs)
|
||||||
|
|
||||||
|
def find_job_id(self, job_name: str):
|
||||||
|
job_name = self.slurm_name(job_name)
|
||||||
|
for job_id, job_info in self.jobs.items():
|
||||||
|
if job_info.name == job_name:
|
||||||
|
return job_id
|
||||||
|
return None
|
||||||
|
|
||||||
|
def submit_array(
|
||||||
|
self,
|
||||||
|
job_name: str,
|
||||||
|
cmd: List[str] | str,
|
||||||
|
count: int,
|
||||||
|
nodes: int,
|
||||||
|
n_gpus_per_node: int,
|
||||||
|
cpus_per_task: int,
|
||||||
|
mem_per_task: int, # MB
|
||||||
|
container_image: str,
|
||||||
|
srun_additional_args: str = "",
|
||||||
|
container_mounts: Optional[str] = None,
|
||||||
|
env_vars: Optional[Dict] = None,
|
||||||
|
nodelist: Optional[str] = None,
|
||||||
|
exclude: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""Submits and launch a job array with SBATCH.
|
||||||
|
Note that a job array has one (unique) slurm name, and one (unique) slurm id.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
job_name (str): The job name of the job array. The actual slurm name will be
|
||||||
|
`<experiment_name>_<trial_name>:<job_name>`.
|
||||||
|
cmd (str or List[str]): The core command to be executed.
|
||||||
|
count (int): The number of jobs in the array.
|
||||||
|
"""
|
||||||
|
assert job_name not in self.job_names, (
|
||||||
|
f"Job {job_name} is already submitted. "
|
||||||
|
"Please use a different job name or stop the existing job."
|
||||||
|
)
|
||||||
|
if isinstance(cmd, str):
|
||||||
|
cmd = [cmd]
|
||||||
|
assert len(cmd) == count, (
|
||||||
|
f"Command length {len(cmd)} does not match the job count {count}. "
|
||||||
|
"Please provide a command for each job in the array."
|
||||||
|
)
|
||||||
|
assert count % nodes == 0, (
|
||||||
|
f"Job count {count} must be divisible by the number of nodes {nodes}. "
|
||||||
|
"Please adjust the job count or the number of nodes."
|
||||||
|
)
|
||||||
|
ntasks_per_node = count // nodes
|
||||||
|
assert n_gpus_per_node % ntasks_per_node == 0, (
|
||||||
|
"GPUs must be evenly distributed across tasks. "
|
||||||
|
f"Current #GPUs per node {n_gpus_per_node}, #tasks per node {ntasks_per_node}."
|
||||||
|
)
|
||||||
|
|
||||||
|
mem_per_cpu = mem_per_task // cpus_per_task # MB per CPU
|
||||||
|
mem_per_node = (
|
||||||
|
mem_per_task * count // nodes + 1024 * 10
|
||||||
|
) # make sure slurm does not run out of resources
|
||||||
|
|
||||||
|
sbatch_options = [
|
||||||
|
f"--job-name={self.slurm_name(job_name)}",
|
||||||
|
f"--output={self.log_path_of(job_name)}",
|
||||||
|
"--open-mode=append",
|
||||||
|
"--no-requeue",
|
||||||
|
f"--nodes={nodes}-{nodes}",
|
||||||
|
f"--ntasks-per-node={ntasks_per_node}",
|
||||||
|
f"--gres=gpu:{n_gpus_per_node}",
|
||||||
|
f"--cpus-per-task={cpus_per_task}",
|
||||||
|
f"--mem={mem_per_node}M",
|
||||||
|
]
|
||||||
|
|
||||||
|
if nodelist:
|
||||||
|
sbatch_options.append(f"--nodelist={nodelist}")
|
||||||
|
if exclude:
|
||||||
|
sbatch_options.append(f"--exclude={exclude}")
|
||||||
|
|
||||||
|
sbatch_options_str = "\n".join([f"#SBATCH {opt}" for opt in sbatch_options])
|
||||||
|
|
||||||
|
if env_vars is None:
|
||||||
|
env_vars = dict()
|
||||||
|
n_gpus_per_task = n_gpus_per_node // ntasks_per_node
|
||||||
|
assert (
|
||||||
|
"CUDA_VISIBLE_DEVICES" not in env_vars
|
||||||
|
), "CUDA_VISIBLE_DEVICES should be automatically resolved by Launcher instead of manually assigned."
|
||||||
|
|
||||||
|
srun_cmds = []
|
||||||
|
for i in range(count):
|
||||||
|
# resolve CUDA_VISIBLE_DEVICES for each task
|
||||||
|
gpu_id_start = (i % ntasks_per_node) * n_gpus_per_task
|
||||||
|
gpu_id_end = ((i % ntasks_per_node) + 1) * n_gpus_per_task
|
||||||
|
node_id = i // ntasks_per_node
|
||||||
|
_env_vars = {
|
||||||
|
**env_vars,
|
||||||
|
"CUDA_VISIBLE_DEVICES": ",".join(
|
||||||
|
str(x) for x in range(gpu_id_start, gpu_id_end)
|
||||||
|
),
|
||||||
|
}
|
||||||
|
# Prepare the command for each job in the array
|
||||||
|
job_cmd = cmd[i]
|
||||||
|
|
||||||
|
if self.container_type == "apptainer":
|
||||||
|
env_string = " ".join(
|
||||||
|
"--env {}={}".format(k, v) for k, v in _env_vars.items()
|
||||||
|
)
|
||||||
|
apptainer_cmd = APPTAINER_CMD_TEMPLATE.format(
|
||||||
|
container_mounts=container_mounts or "",
|
||||||
|
container_env_strings=env_string,
|
||||||
|
container_image=container_image,
|
||||||
|
cmd=job_cmd,
|
||||||
|
)
|
||||||
|
srun_cmd = SRUN_CMD_TEMPLATE.format(
|
||||||
|
additional_args=srun_additional_args,
|
||||||
|
nodes=1,
|
||||||
|
ntasks=1,
|
||||||
|
node_id=node_id,
|
||||||
|
n_gpus_per_node=n_gpus_per_task,
|
||||||
|
cpus_per_task=cpus_per_task,
|
||||||
|
mem_per_cpu=mem_per_cpu,
|
||||||
|
cmd=apptainer_cmd,
|
||||||
|
)
|
||||||
|
elif self.container_type == "none":
|
||||||
|
env_string = "--export=" + ",".join(
|
||||||
|
"{}={}".format(k, v) for k, v in _env_vars.items()
|
||||||
|
)
|
||||||
|
srun_additional_args = srun_additional_args + " " + env_string
|
||||||
|
srun_cmd = SRUN_CMD_TEMPLATE.format(
|
||||||
|
additional_args=srun_additional_args,
|
||||||
|
nodes=1,
|
||||||
|
ntasks=1,
|
||||||
|
node_id=node_id,
|
||||||
|
n_gpus_per_node=n_gpus_per_task,
|
||||||
|
cpus_per_task=cpus_per_task,
|
||||||
|
mem_per_cpu=mem_per_cpu,
|
||||||
|
cmd=job_cmd,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported container type: {self.container_type}. "
|
||||||
|
"Supported types are 'apptainer' and 'none'."
|
||||||
|
)
|
||||||
|
srun_cmds.append(srun_cmd)
|
||||||
|
|
||||||
|
srun_cmds = "\n".join(srun_cmds)
|
||||||
|
sbatch_script = SBATCH_SCRIPT_TEMPLATE.format(
|
||||||
|
sbatch_options=sbatch_options_str,
|
||||||
|
srun_additional_args=srun_additional_args,
|
||||||
|
srun_cmds=srun_cmds,
|
||||||
|
)
|
||||||
|
sbatch_file_path = self.sbatch_path_of(f"{job_name}")
|
||||||
|
with open(sbatch_file_path, "w") as f:
|
||||||
|
f.write(sbatch_script)
|
||||||
|
|
||||||
|
# Submit the job
|
||||||
|
try:
|
||||||
|
output = (
|
||||||
|
subprocess.check_output(["sbatch", sbatch_file_path])
|
||||||
|
.decode("utf-8")
|
||||||
|
.strip()
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Submitted Slurm job {self.slurm_name(job_name)} to scheduler. To check the output, run \n\t`tail -f {self.log_path_of(job_name)}`."
|
||||||
|
)
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to submit job {self.slurm_name(job_name)}. "
|
||||||
|
f"For debugging, please make sure your sbatch command works "
|
||||||
|
f"and check generated sbatch file on {sbatch_file_path}."
|
||||||
|
)
|
||||||
|
logger.error(f"Error message: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
match = re.search(r"Submitted batch job (\d+)", output)
|
||||||
|
slurm_job_id = int(match.group(1)) if match else None
|
||||||
|
if slurm_job_id is None:
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to obtain job id for job {self.slurm_name(job_name)}. "
|
||||||
|
f"sbatch output: {output}"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
assert isinstance(slurm_job_id, int)
|
||||||
|
self.jobs[slurm_job_id] = JobInfo(
|
||||||
|
name=self.slurm_name(job_name),
|
||||||
|
state=JobState.PENDING,
|
||||||
|
slurm_id=slurm_job_id,
|
||||||
|
)
|
||||||
|
self._update_all()
|
||||||
|
|
||||||
|
def stop(self, job_name, force=False):
|
||||||
|
"""Stops a running job.
|
||||||
|
|
||||||
|
Raises exception if there is no such job, but passes if the job
|
||||||
|
has stopped either successfully or not.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
job_name: The job name of the job array to stop.
|
||||||
|
The actual slurm job name will be `<experiment_name>_<trial_name>:<job_name>`.
|
||||||
|
"""
|
||||||
|
signal = "SIGKILL" if force else "SIGTERM"
|
||||||
|
job_id = self.find_job_id(job_name)
|
||||||
|
if not job_id:
|
||||||
|
return
|
||||||
|
return cancel_jobs(slurm_ids=[job_id], signal=signal)
|
||||||
|
|
||||||
|
def stop_all(self, force=False):
|
||||||
|
"""Stops all running jobs."""
|
||||||
|
signal = "SIGKILL" if force else "SIGTERM"
|
||||||
|
return cancel_jobs(slurm_ids=list(self.jobs.keys()), signal=signal)
|
||||||
|
|
||||||
|
def find(self, job_name) -> JobInfo | None:
|
||||||
|
"""Gets the status of a job of this job.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
job_name: The job name of the job array to find.
|
||||||
|
The actual slurm job name will be `<experiment_name>_<trial_name>:<job_name>`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A JobInfo if the job is found, or None otherwise.
|
||||||
|
"""
|
||||||
|
self._update_all()
|
||||||
|
job_id = self.find_job_id(job_name)
|
||||||
|
return self.jobs[job_id] if job_id else None
|
||||||
|
|
||||||
|
def find_all(self, job_name_regex=".*") -> List[JobInfo]:
|
||||||
|
"""Finds jobs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
job_name_regex: job name regex.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of found JobInfo.
|
||||||
|
"""
|
||||||
|
self._update_all()
|
||||||
|
infos = []
|
||||||
|
for r in self.jobs.values():
|
||||||
|
job_name = r.name.split(":")[-1] # Extract the job name from slurm name
|
||||||
|
if re.fullmatch(job_name_regex, job_name):
|
||||||
|
infos.append(r)
|
||||||
|
return infos
|
||||||
|
|
||||||
|
def _find_job_with_status(
|
||||||
|
self,
|
||||||
|
status: List[JobState],
|
||||||
|
) -> List[JobInfo]:
|
||||||
|
"""Finds jobs with the given status.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
status: A list of JobState to filter jobs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of JobInfo with the given status.
|
||||||
|
"""
|
||||||
|
self._update_all()
|
||||||
|
return [r for r in self.jobs.values() if r.state in status]
|
||||||
|
|
||||||
|
def wait(
|
||||||
|
self,
|
||||||
|
timeout=None,
|
||||||
|
check_status: Tuple[JobState, ...] = (
|
||||||
|
JobState.CANCELLED,
|
||||||
|
JobState.FAILED,
|
||||||
|
JobState.NOT_FOUND,
|
||||||
|
),
|
||||||
|
remove_status: Tuple[JobState, ...] = (JobState.COMPLETED,),
|
||||||
|
update=False,
|
||||||
|
):
|
||||||
|
"""Waits until all jobs submitted via this client instance finish."""
|
||||||
|
# begin wait
|
||||||
|
deadline = None if timeout is None else time.time() + timeout
|
||||||
|
|
||||||
|
num_jobs_left = len(self.jobs)
|
||||||
|
left = list(self.jobs.keys())
|
||||||
|
logger.info(
|
||||||
|
f"Waiting for {num_jobs_left} jobs. Jobs IDs: "
|
||||||
|
f"{','.join(sorted([str(x.slurm_id) for x in self.jobs.values()]))}."
|
||||||
|
)
|
||||||
|
while len(left) > 0:
|
||||||
|
if len(left) < num_jobs_left:
|
||||||
|
num_jobs_left = len(left)
|
||||||
|
logger.info(f"Waiting for {num_jobs_left} jobs.")
|
||||||
|
if deadline is not None and time.time() > deadline:
|
||||||
|
raise TimeoutError(
|
||||||
|
f"Timeout waiting for {num_jobs_left} jobs. Job ID: "
|
||||||
|
f"{','.join(sorted([str(x.slurm_id) for x in self.jobs.values()]))}."
|
||||||
|
)
|
||||||
|
self._update_all()
|
||||||
|
left = list(self.jobs.keys())
|
||||||
|
for slurm_id in list(left):
|
||||||
|
slurm_info = self.jobs[slurm_id]
|
||||||
|
if slurm_info.slurm_id is None:
|
||||||
|
continue
|
||||||
|
if slurm_info.state in check_status:
|
||||||
|
raise JobException(
|
||||||
|
run_name=self.run_name,
|
||||||
|
worker_type=slurm_info.name,
|
||||||
|
host=slurm_info.host,
|
||||||
|
reason=slurm_info.state,
|
||||||
|
)
|
||||||
|
if slurm_info.state in remove_status:
|
||||||
|
logger.info(
|
||||||
|
f"Job {slurm_info.name} is {slurm_info.state}. (Removed)"
|
||||||
|
)
|
||||||
|
left.remove(slurm_id)
|
||||||
|
if update:
|
||||||
|
self.jobs.pop(slurm_info.slurm_id)
|
||||||
|
time.sleep(SLURM_WAIT_CHECK_TIME_INTERVAL)
|
||||||
|
|
||||||
|
def _update_all(self):
|
||||||
|
"""Updates the status of all jobs."""
|
||||||
|
try:
|
||||||
|
slurm_infos = query_jobs(slurm_ids=list(self.jobs.keys()))
|
||||||
|
for slurm_info in slurm_infos:
|
||||||
|
assert slurm_info.slurm_id is not None
|
||||||
|
self.jobs[slurm_info.slurm_id] = slurm_info
|
||||||
|
except subprocess.CalledProcessError:
|
||||||
|
logger.warning(
|
||||||
|
"Calling squeue failed. Check slurm manually if you continue to see this warning."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def slurm_main():
|
||||||
|
config, _ = parse_cli_args(sys.argv[2:])
|
||||||
|
config.launcher = to_structured_cfg(config.launcher, LauncherConfig)
|
||||||
|
config.cluster = to_structured_cfg(config.cluster, ClusterSpecConfig)
|
||||||
|
config.sglang = to_structured_cfg(config.sglang, SGLangConfig)
|
||||||
|
validate_config_for_distributed_launcher(config)
|
||||||
|
|
||||||
|
name_resolve.reconfigure(config.cluster.name_resolve)
|
||||||
|
name_resolve.clear_subtree(
|
||||||
|
names.trial_root(
|
||||||
|
experiment_name=config.experiment_name, trial_name=config.trial_name
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
n_nodes = config.cluster.n_nodes
|
||||||
|
n_gpus_per_node = config.cluster.n_gpus_per_node
|
||||||
|
|
||||||
|
launcher = SlurmLauncher(
|
||||||
|
experiment_name=config.experiment_name,
|
||||||
|
trial_name=config.trial_name,
|
||||||
|
fileroot=config.cluster.fileroot,
|
||||||
|
container_type=config.launcher.slurm.container_type,
|
||||||
|
)
|
||||||
|
allocation_mode = config.allocation_mode
|
||||||
|
allocation_mode = AllocationMode.from_str(allocation_mode)
|
||||||
|
sglang_cmds = []
|
||||||
|
sglang_addrs = []
|
||||||
|
n_sglang_nodes = 0
|
||||||
|
if allocation_mode.type_ == AllocationType.DECOUPLED_SGLANG:
|
||||||
|
    # Launcher should launch SGLang servers according to allocation mode.
    sglang_tp_size = allocation_mode.gen_tp_size
    n_sglang_servers = allocation_mode.gen_dp_size
    n_sglang_nodes = allocation_mode.gen_world_size // n_gpus_per_node

    base_seed = config.sglang.random_seed
    sglang_server_cmd_template = f"python3 -m arealite.launcher.sglang_server {' '.join(sys.argv[2:])} sglang.random_seed={{seed}}"
    for i in range(n_sglang_servers):
        sglang_cmd = sglang_server_cmd_template.format(
            seed=base_seed + i,
        )
        sglang_cmds.append(sglang_cmd)

    launcher.submit_array(
        job_name="llm_server",
        cmd=sglang_cmds,
        count=n_sglang_servers,
        nodes=n_sglang_nodes,
        n_gpus_per_node=config.cluster.n_gpus_per_node,
        cpus_per_task=config.launcher.inference_server_cpus_per_gpu
        * sglang_tp_size,
        mem_per_task=config.launcher.inference_server_mem_per_gpu * sglang_tp_size,
        srun_additional_args=config.launcher.slurm.srun_additional_args,
        container_image=config.launcher.slurm.inference_server_image,
        container_mounts=config.launcher.slurm.mount,
        env_vars=get_env_vars(
            config.cluster.cluster_name,
            config.launcher.inference_server_env_vars,
        ),
    )
    # Get SGLang server addresses by name resolve
    try:
        sglang_addrs = wait_sglang_server_addrs(
            config.experiment_name,
            config.trial_name,
            n_sglang_servers,
        )
    except TimeoutError as e:
        launcher.stop_all(force=True)
        raise e

    trainer_n_nodes = n_nodes - n_sglang_nodes
    # Here $head_node_ip is the IP address of the first node in the job array.
    # $trainer_port is a free port on the head node.
    # Both of them are obtained by the SBATCH script.
    trainer_cmd_template = (
        f"torchrun --nnodes={{nnodes}} --nproc-per-node={{nproc_per_node}} --node-rank {{node_rank}} "
        f"--master-addr $head_node_ip --master-port $trainer_port {' '.join(sys.argv[1:])}"
    )

    trainer_cmds = []
    for i in range(trainer_n_nodes):
        # In Slurm, we launch trainers at the granularity of nodes with the torchrun command.
        trainer_cmds.append(
            trainer_cmd_template.format(
                nnodes=trainer_n_nodes,
                nproc_per_node=config.cluster.n_gpus_per_node,
                node_rank=i,
            )
        )

    if not config.server_only:
        # launch trainers
        launcher.submit_array(
            job_name="trainer",
            cmd=trainer_cmds,
            count=trainer_n_nodes,
            nodes=trainer_n_nodes,
            n_gpus_per_node=config.cluster.n_gpus_per_node,
            cpus_per_task=config.launcher.trainer_cpus_per_gpu
            * config.cluster.n_gpus_per_node,
            mem_per_task=config.launcher.trainer_mem_per_gpu
            * config.cluster.n_gpus_per_node,
            container_image=config.launcher.slurm.trainer_image,
            srun_additional_args=config.launcher.slurm.srun_additional_args,
            container_mounts=config.launcher.slurm.mount,
            env_vars=dict(
                **get_env_vars(
                    config.cluster.cluster_name,
                    config.launcher.trainer_env_vars,
                ),
                AREAL_LLM_SERVER_ADDRS=",".join(sglang_addrs),
            ),
        )

    try:
        launcher.wait(
            check_status=(
                JobState.CANCELLED,
                JobState.FAILED,
                JobState.NOT_FOUND,
                JobState.COMPLETED,
            ),
            remove_status=(),
        )
    except (KeyboardInterrupt, JobException, TimeoutError) as e:
        launcher.stop_all(force=True)
        raise e


if __name__ == "__main__":
    # usage: python -m arealite.launcher.slurm <entry_point> \
    #   --config <config_path> [<additional_args>]
    slurm_main()
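For orientation, here is a hedged sketch (not part of the commit) of what the two command templates above expand to once the launcher fills in its placeholders. The entry point, config path, seed, and node counts below are invented for illustration.

# Hypothetical expansion of the launcher's command templates; all concrete
# values (entry point, config path, seed, node counts) are made up.
sglang_server_cmd_template = (
    "python3 -m arealite.launcher.sglang_server --config cfg.yaml sglang.random_seed={seed}"
)
print(sglang_server_cmd_template.format(seed=1 + 2))
# python3 -m arealite.launcher.sglang_server --config cfg.yaml sglang.random_seed=3

trainer_cmd_template = (
    "torchrun --nnodes={nnodes} --nproc-per-node={nproc_per_node} --node-rank {node_rank} "
    "--master-addr $head_node_ip --master-port $trainer_port entry.py --config cfg.yaml"
)
print(trainer_cmd_template.format(nnodes=4, nproc_per_node=8, node_rank=0))
# torchrun --nnodes=4 --nproc-per-node=8 --node-rank 0 --master-addr $head_node_ip --master-port $trainer_port entry.py --config cfg.yaml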
@@ -2,11 +2,9 @@ import os
 import subprocess
 import sys
 import time
-import uuid

 import pytest
 import requests
-import torch

 from arealite.api.cli_args import (
     InferenceEngineConfig,
@@ -18,7 +16,6 @@ from arealite.api.io_struct import FinetuneSpec, WeightUpdateMeta
 from arealite.engine.fsdp_engine import FSDPEngine
 from arealite.engine.sglang_remote import RemoteSGLangEngine
 from arealite.utils.network import find_free_ports
-from realhf.api.core.data_api import load_hf_tokenizer
 from realhf.base import network

 EXPR_NAME = "test_fsdp_engine_nccl"
@@ -399,9 +399,15 @@ def pad_packed_tensor_dict(
             padded_data[key] = new_max_seqlen
         elif torch.is_tensor(value) and value.numel() == total_length:
             # Pad the tensor to the new total length
-            padded_tensor = torch.nn.functional.pad(
-                value, (0, pad_length), value=pad_value
-            )
+            if key == "position_ids":
+                # transformers will compute flash-attn arguments (e.g., cu_seqlens_q)
+                # according to these position ids.
+                pad = torch.arange(pad_length, dtype=torch.long, device=value.device)
+                padded_tensor = torch.cat([value, pad])
+            else:
+                padded_tensor = torch.nn.functional.pad(
+                    value, (0, pad_length), value=pad_value
+                )
             padded_data[key] = padded_tensor
         else:
             padded_data[key] = value
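To make the intent of the position_ids branch above concrete, here is a minimal sketch (not from the repository) assuming, as the comment says, that downstream code derives flash-attention sequence boundaries from positions restarting at zero; padding with a fresh arange therefore makes the pad tokens look like a separate dummy sequence instead of extending the last real one.

# Minimal illustration of the padding rule above (assumption: sequence
# boundaries are inferred from position_ids restarting at 0).
import torch

def pad_position_ids(position_ids: torch.Tensor, pad_length: int) -> torch.Tensor:
    # Append a fresh 0..pad_length-1 ramp so the pad forms its own dummy sequence.
    pad = torch.arange(pad_length, dtype=torch.long, device=position_ids.device)
    return torch.cat([position_ids, pad])

packed = torch.tensor([0, 1, 2, 0, 1])  # two packed sequences of lengths 3 and 2
print(pad_position_ids(packed, 3))      # tensor([0, 1, 2, 0, 1, 0, 1, 2])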
@@ -0,0 +1,102 @@
import getpass
import os
import pathlib
import time
from typing import Dict, Optional

from arealite.api.io_struct import AllocationMode, AllocationType
from realhf.base import logging, name_resolve, names

logger = logging.getLogger("Launcher Utils")

LOCAL_CACHE_DIR = "/tmp/arealite"
PYTORCH_KERNEL_CACHE_PATH = (
    f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/torch/kernels/"
)
TRITON_CACHE_PATH = f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/triton/"
os.makedirs(PYTORCH_KERNEL_CACHE_PATH, exist_ok=True)
os.makedirs(TRITON_CACHE_PATH, exist_ok=True)
BASE_ENVIRONS = {
    "TOKENIZERS_PARALLELISM": "true",
    "PYTORCH_KERNEL_CACHE_PATH": PYTORCH_KERNEL_CACHE_PATH,
    "TRITON_CACHE_DIR": TRITON_CACHE_PATH,
    "CUDA_DEVICE_MAX_CONNECTIONS": "1",
    "PYTHONPATH": str(pathlib.Path(__file__).resolve().parent.parent.parent),
}
NA132_ENVIRONS = {
    "NCCL_SOCKET_IFNAME": "bond0",
    "NCCL_NET_PLUGIN": "",
    "NCCL_IB_GID_INDEX": "3",
    "NCCL_IB_TIMEOUT": "2",
    "NCCL_IB_RETRY_CNT": "7",
    "NCCL_IB_SL": "5",
    "NCCL_IB_TC": "136",
    "NCCL_IB_HCA": "mlx5_bond",
    "NCCL_IB_QPS_PER_CONNECTION": "8",
    "NCCL_SET_THREAD_NAME": "1",
    "NCCL_DEBUG": "WARN",
    "NCCL_DEBUG_SUBSYS": "INIT,TUNING,GRAPH",
}
SGLANG_SERVER_WAIT_TIMEOUT_SECONDS = 180


def get_env_vars(
    cluster_name: str, additional_env_vars: Optional[str] = None
) -> Dict[str, str]:
    """Returns the environment variables for the cluster."""
    _additional_env_vars = (
        dict(item.split("=") for item in additional_env_vars.split(","))
        if additional_env_vars
        else dict()
    )
    if cluster_name == "na132":
        return {**BASE_ENVIRONS, **NA132_ENVIRONS, **_additional_env_vars}
    else:
        return {**BASE_ENVIRONS, **_additional_env_vars}


def wait_sglang_server_addrs(
    experiment_name: str,
    trial_name: str,
    n_sglang_servers: int,
):
    # Get SGLang slurm nodes, find the hosts
    name = names.gen_servers(experiment_name, trial_name)
    start = time.perf_counter()
    while True:
        sglang_addrs = name_resolve.get_subtree(name)
        if len(sglang_addrs) >= n_sglang_servers:
            logger.info(
                f"Found {len(sglang_addrs)} SGLang servers: {', '.join(sglang_addrs)}"
            )
            break

        time.sleep(1)
        if time.perf_counter() - start > SGLANG_SERVER_WAIT_TIMEOUT_SECONDS:
            raise TimeoutError(
                f"Timeout waiting for SGLang servers to be ready. "
                f"Expected {n_sglang_servers} servers, found {len(sglang_addrs)}."
            )
    return sglang_addrs


def validate_config_for_distributed_launcher(config):
    n_nodes = config.cluster.n_nodes
    n_gpus_per_node = config.cluster.n_gpus_per_node
    allocation_mode = config.allocation_mode
    allocation_mode = AllocationMode.from_str(allocation_mode)
    assert (
        allocation_mode.gen_world_size + allocation_mode.train_world_size
        == n_nodes * n_gpus_per_node
    ), (
        f"#GPUs required for allocation mode {allocation_mode.gen_world_size + allocation_mode.train_world_size} "
        f"is not equal to #GPUs in the config {n_nodes * n_gpus_per_node}."
    )
    if allocation_mode.type_ == AllocationType.DECOUPLED_SGLANG:
        # Launcher should launch SGLang servers according to allocation mode.
        assert (
            allocation_mode.gen_pp_size == 1
        ), "Pipeline generation in SGLang is not supported for now."
        assert (
            allocation_mode.gen_tp_size <= config.cluster.n_gpus_per_node
        ), "Currently only support SGLang TP size <= #GPUs per node."
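A hedged usage sketch of the helpers above (not part of the commit); it assumes the module is importable, and the comma-separated KEY=VALUE format of additional_env_vars is inferred from the parsing code.

# Assumes the utilities above are importable; the override string format
# ("KEY=VALUE,KEY2=VALUE2") is inferred from the parsing logic.
env = get_env_vars("na132", "NCCL_DEBUG=INFO,WANDB_MODE=offline")
# BASE_ENVIRONS and NA132_ENVIRONS are applied first, then the overrides,
# so NCCL_DEBUG ends up as "INFO" instead of the default "WARN".
assert env["NCCL_DEBUG"] == "INFO" and env["WANDB_MODE"] == "offline"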
@@ -1,48 +0,0 @@
from typing import List

import torch
from tensordict import TensorDict


def concat_padded_tensors(
    tensor_dicts: List[TensorDict], pad_value: float = 0.0
) -> TensorDict:
    """Concatenate and pad tensors from multiple padded tensor dictionaries."""
    if not tensor_dicts:
        return TensorDict()

    batch_sizes = [tuple(d.batch_size) for d in tensor_dicts]
    new_batch_size = [sum(x[0] for x in batch_sizes), *batch_sizes[0][1:]]

    # Find max sequence length across all dictionaries
    assert all("attention_mask" in td for td in tensor_dicts)
    max_length = max([x["attention_mask"].shape[1] for x in tensor_dicts])
    result = {}
    # Process each key
    for key in tensor_dicts[0].keys():
        tensors_to_concat = []
        for tensor_dict in tensor_dicts:
            tensor = tensor_dict[key]
            # Skip 1D tensors like rewards
            if len(tensor.shape) == 1:
                tensors_to_concat.append(tensor)
                continue
            current_length = tensor.shape[1]
            if current_length < max_length:
                # Pad tensor to max_length
                pad_width = max_length - current_length
                if key == "attention_mask":
                    # Pad attention mask with 0s
                    padding = torch.zeros(
                        (tensor.shape[0], pad_width), dtype=tensor.dtype
                    )
                else:
                    # Pad feature tensors with pad_value
                    padding = torch.full(
                        (tensor.shape[0], pad_width), pad_value, dtype=tensor.dtype
                    )
                tensor = torch.cat([tensor, padding], dim=1)
            tensors_to_concat.append(tensor)

        result[key] = torch.cat(tensors_to_concat, dim=0)
    return TensorDict(result, batch_size=new_batch_size)
@@ -0,0 +1,23 @@
import ray
from ray.util.placement_group import PlacementGroup
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

from arealite.utils.network import find_free_ports, gethostip


def get_placement_group_master_ip_and_port(placement_group: PlacementGroup):
    def _master_ip_and_port():
        host_ip = gethostip()
        port = find_free_ports(1, (10000, 60000))[0]
        return host_ip, port

    future = ray.remote(
        num_cpus=1,
        num_gpus=0,
        memory=10 * 1024 * 1024,  # Convert MB to bytes
        scheduling_strategy=PlacementGroupSchedulingStrategy(
            placement_group=placement_group,
            placement_group_bundle_index=0,
        ),
    )(_master_ip_and_port).remote()
    return ray.get(future)
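For context, a hedged sketch of how this helper might be used (bundle shapes and the PACK strategy below are illustrative, not taken from the repository): a tiny task is scheduled onto bundle 0 of the placement group, so the returned IP and free port are valid on the node that will host rank 0 of the trainer.

# Hypothetical usage; bundle shapes and the PACK strategy are invented.
import ray
from ray.util.placement_group import placement_group

ray.init()
pg = placement_group([{"CPU": 8, "GPU": 1}] * 2, strategy="PACK")
ray.get(pg.ready())
master_ip, master_port = get_placement_group_master_ip_and_port(pg)
print(f"torchrun --master-addr {master_ip} --master-port {master_port} ...")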
@@ -0,0 +1,154 @@
import subprocess
from typing import List, Literal, Optional

from realhf.base import logging
from realhf.scheduler.client import JobInfo, JobState

logger = logging.getLogger("Slurm Utils")


SQUEUE_FIELDS = [
    "JobID",
    "State",
    "SubmitTime",
    "StartTime",
    "Name",
    "NodeList",
    "UserName",
    "MaxCPUs",
    "cpus-per-task",
    "NumTasks",
    "tres-alloc",
]
STATUS_MAPPING = {
    "RUNNING": JobState.RUNNING,
    "COMPLETING": JobState.RUNNING,
    "PENDING": JobState.PENDING,
    "CANCELLED": JobState.CANCELLED,
    "FAILED": JobState.FAILED,
    "COMPLETED": JobState.COMPLETED,
    "OUT_OF_MEMORY": JobState.FAILED,
    "DEADLINE": JobState.COMPLETED,
    "TIMEOUT": JobState.COMPLETED,
}

SBATCH_SCRIPT_TEMPLATE = """#!/bin/bash
{sbatch_options}

# Getting the node names
nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
echo nodes=$nodes

nodes_array=($nodes)
echo node_array=$nodes_array

head_node=${{nodes_array[0]}}
echo head_node=$head_node

# Getting the head node IP address
head_node_ip=$(srun {srun_additional_args} --nodes=1 --ntasks=1 -n1 -c1 --mem=10M --nodelist="$head_node" hostname --ip-address)
echo head_node_ip=$head_node_ip

# Find a free port on the head node
# Wonderful linux command to find a random free port (between 10000 and 60000) by deepseek
trainer_port=$(srun {srun_additional_args} --nodes=1 --ntasks=1 -n1 -c1 --mem=10M --nodelist="$head_node" bash -c "comm -23 <(seq 10000 60000 | sort) <(ss -tan | awk '{{print $4}}' | cut -d':' -f2 | grep '[0-9]\\{{1,5\\}}' | sort -u) | shuf | head -n 1")
echo trainer_port=$trainer_port

# srun commands
{srun_cmds}

wait
"""

SRUN_CMD_TEMPLATE: str = """srun {additional_args} \\
--nodelist=${{nodes_array[{node_id}]}} --nodes={nodes} --ntasks={ntasks} \\
--gres=gpu:{n_gpus_per_node} --cpus-per-task={cpus_per_task} --mem-per-cpu={mem_per_cpu}M \\
{cmd} &
"""

APPTAINER_CMD_TEMPLATE: str = """singularity exec --no-home --writable-tmpfs --nv --pid \\
--bind {container_mounts} \\
{container_env_strings} \\
{container_image} \\
{cmd}"""


def cancel_jobs(
    slurm_names: Optional[List[str]] = None,
    slurm_ids: Optional[List[int]] = None,
    signal: Literal["SIGINT", "SIGKILL"] = "SIGKILL",
):
    assert (
        slurm_names is not None or slurm_ids is not None
    ), "Must specify slurm_names or slurm_ids."
    assert not (
        slurm_names and slurm_ids
    ), "Cannot specify both slurm_names and slurm_ids."
    cmd = ["scancel", "-s", signal]
    if slurm_names is not None:
        cmd += ["-n", ",".join(slurm_names)]
    elif slurm_ids is not None:
        cmd += [",".join(str(s) for s in slurm_ids)]
    subprocess.check_call(cmd)
    logger.info(
        f"Cancelled Slurm job with signal {signal}: "
        f"slurm identifiers {slurm_names if slurm_ids is None else slurm_ids}. CMD: {cmd}"
    )


def query_jobs(
    slurm_names: Optional[List[str]] = None,
    slurm_ids: Optional[List[int]] = None,
    status: str = "all",
    delimiter: str = "__PSI__",
) -> List[JobInfo]:
    squeue_format = f":.{delimiter},".join(SQUEUE_FIELDS)
    cmd = ["squeue", "-O", squeue_format, f"-t{status}"]
    if slurm_names is not None:
        cmd += ["-n", ",".join(slurm_names)]
    if slurm_ids is not None:
        cmd += ["-j", ",".join([str(s) for s in slurm_ids])]

    output = (
        subprocess.check_output(cmd, stderr=subprocess.DEVNULL).decode("ascii").strip()
    )
    rs = []
    for line in output.split("\n")[1:]:
        job_id, state, submit_time, start_time, slurm_name, nodelist, *_ = line.split(
            delimiter
        )
        rs.append(
            JobInfo(
                name=slurm_name,
                state=STATUS_MAPPING[state],
                host=nodelist,
                submit_time=submit_time,
                start_time=start_time,
                slurm_id=int(job_id.strip()),
            )
        )
    return rs


def parse_slurm_nodelist(nodelist: str) -> List[str]:
    return (
        subprocess.check_output(
            [
                "scontrol",
                "show",
                "hostnames",
                nodelist,
            ]
        )
        .decode("utf-8")
        .strip()
        .split("\n")
    )


def get_slurm_host_ip(node: str, srun_addtional_args: str):
    try:
        cmd = f"srun {srun_addtional_args} --immediate=1 --nodes=1 --ntasks=1 -n1 -c1 --mem=10M --nodelist={node} hostname --ip-address"
        return subprocess.check_output(cmd.split(" ")).decode("utf-8").strip()
    except subprocess.CalledProcessError:
        logger.warning(f"Failed to get Slurm host IP for node {node}.")
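A hedged sketch of how these helpers fit together (the job names below are made up; real names are derived from the experiment and trial names elsewhere in the launcher):

# Hypothetical job names for illustration only.
for job in query_jobs(slurm_names=["myexp_mytrial:llm_server"]):
    print(job.name, job.state, job.host)

# Force-stop everything belonging to the trial.
cancel_jobs(
    slurm_names=["myexp_mytrial:llm_server", "myexp_mytrial:trainer"],
    signal="SIGKILL",
)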
@@ -73,7 +73,7 @@ class StatsLogger:
         )
         if isinstance(data, Dict):
             data = [data]
-        log_step = max(global_step, self._last_commit_step)
+        log_step = max(global_step, self._last_commit_step + 1)
         for i, item in enumerate(data):
             self.info(f"Stats ({i+1}/{len(data)}):")
             self.print_stats(item)
@@ -1,11 +1,14 @@
 import asyncio
+import os
 import uuid

+import colorama
 import torch
 from tensordict import TensorDict
 from transformers import PreTrainedTokenizerFast

 from arealite.api.cli_args import GenerationHyperparameters
+from arealite.api.engine_api import InferenceEngine
 from arealite.api.io_struct import LLMRequest
 from arealite.api.workflow_api import RolloutWorkflow
 from arealite.utils.data import concat_padded_tensors
@@ -18,13 +21,17 @@ class RLVRWorkflow(RolloutWorkflow):
         gconfig: GenerationHyperparameters,
         tokenizer: PreTrainedTokenizerFast,
         enable_thinking: bool,
+        dump_dir: str | None = None,
     ):
         self.reward_fn = reward_fn
         self.gconfig = gconfig
         self.tokenizer = tokenizer
         self.enable_thinking = enable_thinking
+        self.dump_dir = dump_dir
+        if self.dump_dir is not None and not os.path.exists(self.dump_dir):
+            os.makedirs(self.dump_dir, exist_ok=True)

-    async def arun_episode(self, engine, data):
+    async def arun_episode(self, engine: InferenceEngine, data):
         input_ids = self.tokenizer.apply_chat_template(
             data["messages"],
             tokenize=True,
@@ -39,6 +46,12 @@ class RLVRWorkflow(RolloutWorkflow):
         )
         resps = await asyncio.gather(*[engine.agenerate(req) for _ in range(n_samples)])

+        version = engine.get_version()
+        prompt_strs = []
+        completions_strs = []
+        rewards = []
+        seqlens = []
+
         results = []
         for resp in resps:
             seq = resp.input_tokens + resp.output_tokens
@@ -46,13 +59,19 @@ class RLVRWorkflow(RolloutWorkflow):
             loss_mask = [0] * resp.input_len + [1] * resp.output_len
             versions = [-1] * resp.input_len + resp.output_versions

+            prompt_str = self.tokenizer.decode(input_ids)
+            completions_str = self.tokenizer.decode(resp.output_tokens)
+            prompt_strs.append(prompt_str)
+            completions_strs.append(completions_str)
+            seqlens.append(len(seq))
             reward = self.reward_fn(
-                prompt=self.tokenizer.decode(input_ids),
-                completions=self.tokenizer.decode(resp.output_tokens),
+                prompt=prompt_str,
+                completions=completions_str,
                 prompt_ids=resp.input_tokens,
                 completion_ids=resp.output_tokens,
                 **data,
             )
+            rewards.append(reward)
             res = dict(
                 # unsqueeze to add an additional batch dimension
                 input_ids=torch.tensor(seq).unsqueeze(0),
@@ -65,4 +84,31 @@ class RLVRWorkflow(RolloutWorkflow):
             )
             results.append(TensorDict(res, batch_size=[1]))

+        if self.dump_dir is not None:
+            os.makedirs(os.path.join(self.dump_dir, str(version)), exist_ok=True)
+            # Get the unique identifier for this prompt
+            qid = None
+            for key in ["query_id", "id", "qid"]:
+                qid = data.get(key, None)
+                if qid is not None:
+                    break
+            qid = qid or uuid.uuid4().hex
+
+            # Dump rollout to file
+            with open(
+                os.path.join(self.dump_dir, str(version), f"{qid}.txt"), "a"
+            ) as f:
+                n_samples = self.gconfig.n_samples
+                for i, (p, c, r, sl) in enumerate(
+                    zip(prompt_strs, completions_strs, rewards, seqlens)
+                ):
+                    info = "\n".join(
+                        [
+                            f"idx: {i + 1} / {n_samples}, seqlen: {sl}, reward is {r}.",
+                            f"prompt is \n{colorama.Fore.YELLOW + colorama.Style.DIM}{p}{colorama.Style.RESET_ALL}",
+                            f"sequence is: \n{colorama.Fore.YELLOW + colorama.Style.DIM}{c}{colorama.Style.RESET_ALL}",
+                        ]
+                    )
+                    f.write(info + "\n")
+
         return concat_padded_tensors(results)
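For reference, the call site above implies that a reward function must accept prompt, completions, prompt_ids, and completion_ids, plus arbitrary dataset fields via **kwargs. A hedged sketch of a compatible function follows; the exact-match criterion and the "answer" field are illustrative only, not the project's reward.

# Illustrative reward function matching the keyword signature used above;
# the "answer" field and exact-match rule are assumptions, not repo code.
def exact_match_reward_fn(prompt, completions, prompt_ids, completion_ids, answer=None, **kwargs) -> float:
    if answer is None:
        return 0.0
    return float(completions.strip() == str(answer).strip())

# workflow = RLVRWorkflow(reward_fn=exact_match_reward_fn, gconfig=..., tokenizer=..., enable_thinking=False)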
@@ -0,0 +1,310 @@
import asyncio
import os
import shutil
import sys
import uuid

import colorama
import torch
import torch.distributed as dist
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node
from tensordict import TensorDict
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import PreTrainedTokenizerFast

from arealite.api.cli_args import (
    GenerationHyperparameters,
    GRPOConfig,
    load_expr_config,
)
from arealite.api.io_struct import FinetuneSpec, LLMRequest, WeightUpdateMeta
from arealite.api.workflow_api import RolloutWorkflow
from arealite.engine.ppo.actor import FSDPPPOActor
from arealite.engine.sglang_remote import RemoteSGLangEngine
from arealite.utils.data import concat_padded_tensors
from arealite.utils.device import log_gpu_stats
from arealite.utils.saver import Saver
from arealite.utils.stats_logger import StatsLogger
from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import logging, seeding, stats_tracker

logger = logging.getLogger("boba math")


class RLVRWorkflow(RolloutWorkflow):
    def __init__(
        self,
        reward_fn,
        gconfig: GenerationHyperparameters,
        tokenizer: PreTrainedTokenizerFast,
        dump_dir: str | None = None,
    ):
        self.reward_fn = reward_fn
        self.gconfig = gconfig
        self.tokenizer = tokenizer
        self.dump_dir = dump_dir
        if self.dump_dir is not None and not os.path.exists(self.dump_dir):
            os.makedirs(self.dump_dir, exist_ok=True)

    async def arun_episode(self, engine, data):
        input_ids = self.tokenizer.encode(data["prompt"])
        n_samples = self.gconfig.n_samples
        req = LLMRequest(
            rid=uuid.uuid4().hex,
            input_ids=input_ids,
            gconfig=self.gconfig.new(n_samples=1),
        )
        resps = await asyncio.gather(*[engine.agenerate(req) for _ in range(n_samples)])

        version = engine.get_version()
        prompt_strs = []
        completions_strs = []
        rewards = []
        seqlens = []

        results = []
        for resp in resps:
            seq = resp.input_tokens + resp.output_tokens
            logprobs = [0.0] * resp.input_len + resp.output_logprobs
            loss_mask = [0] * resp.input_len + [1] * resp.output_len
            versions = [-1] * resp.input_len + resp.output_versions

            prompt_str = data["prompt"]
            completions_str = self.tokenizer.decode(resp.output_tokens)
            prompt_strs.append(prompt_str)
            completions_strs.append(completions_str)
            seqlens.append(len(seq))
            reward = self.reward_fn(
                completions=completions_str,
                prompt_ids=resp.input_tokens,
                completion_ids=resp.output_tokens,
                **data,
            )
            rewards.append(reward)
            res = dict(
                # unsqueeze to add an additional batch dimension
                input_ids=torch.tensor(seq).unsqueeze(0),
                loss_mask=torch.tensor(loss_mask).unsqueeze(0),
                logprobs=torch.tensor(logprobs).unsqueeze(0),
                versions=torch.tensor(versions).unsqueeze(0),
                attention_mask=torch.ones(len(seq), dtype=torch.bool).unsqueeze(0),
                # reward
                rewards=torch.tensor([float(reward)]),
            )
            results.append(TensorDict(res, batch_size=[1]))

        if self.dump_dir is not None:
            os.makedirs(os.path.join(self.dump_dir, str(version)), exist_ok=True)
            # Get the unique identifier for this prompt
            qid = None
            for key in ["query_id", "id", "qid"]:
                qid = data.get(key, None)
                if qid is not None:
                    break
            qid = qid or uuid.uuid4().hex

            # Dump rollout to file
            with open(
                os.path.join(self.dump_dir, str(version), f"{qid}.txt"), "a"
            ) as f:
                n_samples = self.gconfig.n_samples
                for i, (p, c, r, sl) in enumerate(
                    zip(prompt_strs, completions_strs, rewards, seqlens)
                ):
                    info = "\n".join(
                        [
                            f"idx: {i + 1} / {n_samples}, seqlen: {sl}, reward is {r}.",
                            f"prompt is \n{colorama.Fore.YELLOW + colorama.Style.DIM}{p}{colorama.Style.RESET_ALL}",
                            f"sequence is: \n{colorama.Fore.YELLOW + colorama.Style.DIM}{c}{colorama.Style.RESET_ALL}",
                        ]
                    )
                    f.write(info + "\n")

        return concat_padded_tensors(results)


def get_boba_math_dataset(tokenizer, rank, world_size):
    dataset = load_dataset(
        path="json",
        split="train",
        data_files="/storage/openpsi/users/xushusheng.xss/training_data/boba_106k_0319.jsonl",
    )
    dataset = dataset.filter(lambda x: len(tokenizer.encode(x["prompt"])) <= 1024)
    return split_dataset_by_node(dataset, rank=rank, world_size=world_size)


def boba_reward_fn(
    prompt, completions, prompt_ids, completion_ids, query_id, solutions, **kwargs
):
    from pebble import ProcessExpired, ProcessPool

    from realhf.impl.dataset.math_parser import process_results

    jobs = []
    with ProcessPool(max_workers=1) as executor:
        for sol in solutions:
            job = executor.schedule(
                process_results, args=[completions, sol], timeout=15
            )
            jobs.append(job)

        label = 0
        for job in jobs:
            try:
                x = job.result()
            except TimeoutError:
                # print("[debug: timeout]")
                logger.warning(f"Timeout occurred while verifying the math answer.")
                x = (0, "timeout", "timeout")
            except ProcessExpired as e:
                logger.warning(f"Process terminated abnormally: {e}")
                x = (0, "error", "error")
            except Exception as e:
                logger.warning(f"Other error occurred: {e.__class__.__name__}, {e}")
                x = (0, "error", "error")
            label = label or x[0]
    return label


def main(args):
    config, _ = load_expr_config(args, GRPOConfig)
    config: GRPOConfig

    rank = int(os.getenv("RANK"))
    world_size = int(os.getenv("WORLD_SIZE"))
    tokenizer = load_hf_tokenizer(config.tokenizer_path)

    seeding.set_random_seed(config.seed, key=f"trainer{rank}")

    # Create dataset and dataloaders
    train_dataloader = StatefulDataLoader(
        get_boba_math_dataset(tokenizer, rank, world_size),
        batch_size=config.train_dataset.batch_size // world_size,
        shuffle=config.train_dataset.shuffle,
        num_workers=config.train_dataset.num_workers,
        collate_fn=lambda x: x,
        drop_last=config.train_dataset.drop_last,
    )
    ft_spec = FinetuneSpec(
        total_train_epochs=config.total_train_epochs,
        dataset_size=len(train_dataloader) * config.train_dataset.batch_size,
        train_batch_size=config.train_dataset.batch_size,
    )

    # Initialize inference engine
    rollout = RemoteSGLangEngine(config.rollout)
    rollout.initialize(None, ft_spec)

    # Initialize train engine
    actor = FSDPPPOActor(config=config.actor)
    actor.initialize(None, ft_spec)
    ref = None
    if config.actor.kl_ctl > 0 and config.ref is not None:
        ref = FSDPPPOActor(config=config.ref)
        ref.initialize(None, ft_spec)

    # Create rollout workflow
    if tokenizer.pad_token_id not in config.gconfig.stop_token_ids:
        config.gconfig.stop_token_ids.append(tokenizer.pad_token_id)
    if tokenizer.eos_token_id not in config.gconfig.stop_token_ids:
        config.gconfig.stop_token_ids.append(tokenizer.eos_token_id)
    workflow = RLVRWorkflow(
        reward_fn=boba_reward_fn,
        gconfig=config.gconfig,
        tokenizer=tokenizer,
        dump_dir=os.path.join(
            StatsLogger.get_log_path(config.stats_logger), "generated"
        ),
    )

    # Run training.
    saver = Saver(config.saver, ft_spec, for_recover=False)
    logger = StatsLogger(config.stats_logger, ft_spec)

    total_epochs = config.total_train_epochs
    steps_per_epoch = len(train_dataloader)
    max_steps = total_epochs * steps_per_epoch

    logger.info(f"total_epochs={total_epochs} step_per_epoch={steps_per_epoch}")
    data_generator = iter(train_dataloader)
    for global_step in range(max_steps):
        epoch = global_step // steps_per_epoch
        step = global_step % steps_per_epoch

        with stats_tracker.record_timing("rollout"):
            if config.async_training:
                batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
            else:
                try:
                    data = next(data_generator)
                except StopIteration:
                    data_generator = iter(train_dataloader)
                    data = next(data_generator)
                batch = rollout.rollout_batch(data, workflow=workflow)

        batch = batch.to(actor.device)
        # Create barrier to synchronize all rollout processes.
        dist.barrier()
        torch.cuda.synchronize()

        if config.actor.recompute_logprob or config.actor.use_decoupled_loss:
            with stats_tracker.record_timing("recompute_logp"):
                logp = actor.compute_logp(batch)
                batch["prox_logp"] = logp
                log_gpu_stats("recompute logp")

        if ref is not None:
            with stats_tracker.record_timing("ref_logp"):
                batch["ref_logp"] = ref.compute_logp(batch)
                log_gpu_stats("ref logp")

        with stats_tracker.record_timing("compute_advantage"):
            actor.compute_advantages(batch)
            log_gpu_stats("compute advantages")

        with (
            stats_tracker.record_timing("train_step"),
            stats_tracker.scope("grpo_actor"),
        ):
            stats = actor.ppo_update(batch)
            actor.step_lr_scheduler()
            log_gpu_stats("ppo update")

        with stats_tracker.record_timing("update_weights"):
            path = os.path.join(
                Saver.get_save_checkpoint_root(config.saver),
                "update_weights",
                str(global_step + 1),
            )
            meta = WeightUpdateMeta(
                type="disk",
                path=path,
                alloc_mode=None,
                comm_backend=None,
                model_version=global_step + 1,
            )
            if dist.get_rank() == 0:
                future = rollout.update_weights(meta)
            actor.upload_weights(meta)
            if dist.get_rank() == 0:
                future.result()
                shutil.rmtree(path, ignore_errors=True)
            dist.barrier()
            torch.cuda.synchronize()
            rollout.set_version(global_step + 1)

        with stats_tracker.record_timing("save"):
            saver.save(actor, epoch, step, global_step)

        logger.commit(epoch, step, global_step, stats)

    logger.close()
    rollout.destroy()
    if ref is not None:
        ref.destroy()
    actor.destroy()


if __name__ == "__main__":
    main(sys.argv[1:])
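For orientation only: get_boba_math_dataset and boba_reward_fn above imply that each JSONL record carries at least a prompt, a query_id, and a list of reference solutions. A hedged sketch of such a record, with invented values:

# Hypothetical record shape inferred from the dataset/reward code above.
example_record = {
    "prompt": "Compute 1 + 1.",
    "query_id": "sample-0001",
    "solutions": ["\\boxed{2}"],
}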
@@ -0,0 +1,141 @@
experiment_name: lite-boba-math
trial_name: run1

cluster:
  n_nodes: 16
  n_gpus_per_node: 8
  cluster_name: na132
  fileroot: /storage/openpsi/experiments
  name_resolve:
    type: nfs
    nfs_record_root: /storage/openpsi/experiments/name_resolve/lite-boba-math
    etcd3_addr: etcd-client.openpsi-etcd.svc.sigma-na130-lingbo.na130.wl-robby.local:2379

seed: 1
total_train_epochs: 10
total_train_steps: null
tokenizer_path: ${actor.path}
allocation_mode: sglang.d96p1t1+d32p1t1
async_training: true

rollout:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  max_concurrent_rollouts: 400
  queue_size: null
  consumer_batch_size: ${train_dataset.batch_size}
  max_head_offpolicyness: 4
  enable_rollout_tracing: true

gconfig:
  n_samples: 16
  min_new_tokens: 0
  max_new_tokens: 30720
  greedy: false
  temperature: 1.0

actor:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  path: /storage/openpsi/models/deepseek-ai__DeepSeek-R1-Distill-Qwen-1.5B/
  init_from_scratch: false
  disable_dropout: true
  gradient_checkpointing: true
  dtype: bfloat16
  mb_spec:
    max_tokens_per_mb: 32768
  optimizer:
    type: adam
    lr: 1e-5
    weight_decay: 0.01
    beta1: 0.9
    beta2: 0.999
    eps: 1e-8
    lr_scheduler_type: constant
    gradient_clipping: 1.0
    warmup_steps_proportion: 0.001
  backend: fsdp

  group_size: ${gconfig.n_samples}
  group_adv_norm: false
  eps_clip: 0.4
  temperature: ${gconfig.temperature}
  reward_scaling: 10.0
  reward_bias: -0.5
  kl_ctl: 0.0
  ppo_n_minibatches: 4
  recompute_logprob: true
  use_decoupled_loss: true
  behav_imp_weight_cap: 5.0

ref:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  path: ${actor.path}
  init_from_scratch: false
  disable_dropout: true
  dtype: ${actor.dtype}
  mb_spec:
    max_tokens_per_mb: 32768
  optimizer: null
  backend: fsdp

# SGLang
server_only: false
sglang:
  model_path: ${actor.path}
  random_seed: ${seed}
  skip_tokenizer_init: true
  dtype: ${actor.dtype}
  max_running_requests: null
  context_length: 32768
  mem_fraction_static: 0.9

# datasets
train_dataset:
  batch_size: 512
  shuffle: true
  pin_memory: true

# Utilities
saver:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  freq_epochs: 1
  freq_steps: null
  freq_secs: null

checkpointer:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  freq_epochs: 1
  freq_steps: null
  freq_secs: 3600

evaluator:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  freq_epochs: null
  freq_steps: null
  freq_secs: null

stats_logger:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  wandb:
    mode: online

# Launcher
launcher:
  inference_server_cpus_per_gpu: 15
  inference_server_mem_per_gpu: 153600
  trainer_cpus_per_gpu: 15
  trainer_mem_per_gpu: 153600
  slurm:
    mount: /storage:/storage
    trainer_image: /storage/openpsi/images/arealite-20250712-update-hf-xet.sif
    inference_server_image: /storage/openpsi/images/arealite-20250712-update-hf-xet.sif
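A quick sanity check on the allocation_mode in this config (a sketch, not repository code; the d/p/t reading follows the gen_dp_size, gen_pp_size, and gen_tp_size fields used by the launcher):

# "sglang.d96p1t1+d32p1t1": 96 SGLang servers (dp=96, pp=1, tp=1) plus a
# 32-GPU trainer (dp=32, pp=1, tp=1).
gen_world_size = 96 * 1 * 1
train_world_size = 32 * 1 * 1
n_nodes, n_gpus_per_node = 16, 8
assert gen_world_size + train_world_size == n_nodes * n_gpus_per_node  # 128 GPUs
# With 8 GPUs per node: 12 inference nodes and 4 trainer nodes.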
@@ -41,7 +41,7 @@ actor:
     max_tokens_per_mb: 10240
   optimizer:
     type: adam
-    lr: 2e-6
+    lr: 1e-5
     weight_decay: 0.01
     beta1: 0.9
     beta2: 0.999
@@ -1,5 +1,5 @@
 import os
-import re
+import shutil
 import sys

 import torch
@@ -18,7 +18,9 @@ from arealite.utils.saver import Saver
 from arealite.utils.stats_logger import StatsLogger
 from arealite.workflow.rlvr import RLVRWorkflow
 from realhf.api.core.data_api import load_hf_tokenizer
-from realhf.base import stats_tracker
+from realhf.base import logging, seeding, stats_tracker
+
+logger = logging.getLogger("GSM8K grpo")


 def process_gsm8k_rl_dataset(dataset: Dataset):
@@ -36,54 +38,22 @@ def get_gsm8k_dataset(split, rank, world_size):
     return process_gsm8k_rl_dataset(dataset)


-# Adapted from verl.
-def extract_solution(solution_str, method="strict") -> str | None:
-    assert method in ["strict", "flexible"]
-
-    final_answer = None
-    if method == "strict":
-        # this also tests the formatting of the model
-        solutions = re.findall("#### (\\-?[0-9\\.\\,]+)", solution_str)
-        if len(solutions) == 0:
-            final_answer = None
-        else:
-            # take the last solution
-            final_answer = solutions[-1].replace(",", "").replace("$", "")
-    elif method == "flexible":
-        answer = re.findall("(\\-?[0-9\\.\\,]+)", solution_str)
-        final_answer = None
-        if len(answer) == 0:
-            # no reward is there is no answer
-            pass
-        else:
-            invalid_str = ["", "."]
-            # find the last number that is not '.'
-            for final_answer in reversed(answer):
-                if final_answer not in invalid_str:
-                    break
-    return final_answer
-
-
 def gsm8k_reward_fn(prompt, completions, prompt_ids, completion_ids, answer, **kwargs):
-    from realhf.impl.dataset.math_parser import extract_answer
+    from realhf.impl.dataset.math_parser import process_results

-    sol = extract_answer(completions, data_name="math")
-    ans = extract_solution(solution_str=answer, method="strict")
-    if sol is None:
-        return 0
-    if ans is None:
-        return 0
-    return int(sol.strip() == ans.strip())
+    return int(process_results(completions, answer)[0])


-def main_grpo():
-    config, _ = load_expr_config(sys.argv[1:], GRPOConfig)
+def main(args):
+    config, _ = load_expr_config(args, GRPOConfig)
     config: GRPOConfig

     rank = int(os.getenv("RANK"))
     world_size = int(os.getenv("WORLD_SIZE"))
     tokenizer = load_hf_tokenizer(config.tokenizer_path)

+    seeding.set_random_seed(config.seed, key=f"trainer{rank}")
+
     # Create dataset and dataloaders
     train_dataloader = StatefulDataLoader(
         get_gsm8k_dataset("train", rank, world_size),
@@ -133,6 +103,9 @@ def main_grpo():
         gconfig=config.gconfig,
         tokenizer=tokenizer,
         enable_thinking=False,
+        dump_dir=os.path.join(
+            StatsLogger.get_log_path(config.stats_logger), "generated"
+        ),
     )

     # Run training.
@@ -190,13 +163,14 @@ def main_grpo():
             log_gpu_stats("ppo update")

         with stats_tracker.record_timing("update_weights"):
+            path = os.path.join(
+                Saver.get_save_checkpoint_root(config.saver),
+                "update_weights",
+                str(global_step + 1),
+            )
             meta = WeightUpdateMeta(
                 type="disk",
-                path=os.path.join(
-                    Saver.get_save_checkpoint_root(config.saver),
-                    "update_weights",
-                    str(global_step),
-                ),
+                path=path,
                 alloc_mode=None,
                 comm_backend=None,
                 model_version=global_step + 1,
@@ -206,6 +180,7 @@ def main_grpo():
             actor.upload_weights(meta)
             if dist.get_rank() == 0:
                 future.result()
+                shutil.rmtree(path, ignore_errors=True)
             dist.barrier()
             torch.cuda.synchronize()
             rollout.set_version(global_step + 1)
@@ -253,4 +228,4 @@ def main_grpo():


 if __name__ == "__main__":
-    main_grpo()
+    main(sys.argv[1:])
@@ -35,8 +35,8 @@ def get_gsm8k_dataset(split, tokenizer, rank, world_size):
     return process_gsm8k_sft_dataset(dataset, tokenizer)


-def main_sft():
-    config, _ = load_expr_config(sys.argv[1:], SFTConfig)
+def main(args):
+    config, _ = load_expr_config(args, SFTConfig)
     config: SFTConfig

     rank = int(os.getenv("RANK"))
@@ -121,4 +121,4 @@ def main_sft():


 if __name__ == "__main__":
-    main_sft()
+    main(sys.argv[1:])