diff --git a/arealite/api/cli_args.py b/arealite/api/cli_args.py index 906f8fd..3b2afa5 100644 --- a/arealite/api/cli_args.py +++ b/arealite/api/cli_args.py @@ -676,20 +676,6 @@ class ClusterSpecConfig: "help": "Root for logs and checkpoints. Should be available to all nodes." }, ) - gpu_type: str = field( - default="tesla", metadata={"help": "GPU type of the cluster. Used by slurm."} - ) - mount: str = field( - default="/storage:/storage", metadata={"help": "Mount path for slurm."} - ) - gpu_image: str = field(default="", metadata={"help": "slurm image for trainers."}) - cpu_image: str = field(default="", metadata={"help": "slurm image for CPU jobs."}) - gpu_infer_image: str = field( - default="", metadata={"help": "slurm image for LLM inference."} - ) - node_name_prefix: str = field( - default="slurmd-", metadata={"help": "Node prefix for a slurm cluster."} - ) n_nodes: int = field( default=32, metadata={ @@ -725,6 +711,72 @@ class DatasetConfig: drop_last: bool = field(default=True) +@dataclass +class SlurmLauncherConfig: + """Configuration for launching the SGLang server with Slurm.""" + + srun_additional_args: str = field( + default="--overlap --mpi=pmi2 -K --chdir $PWD", + metadata={"help": "Additional arguments to pass to the srun command."}, + ) + container_type: str = field( + default="apptainer", + metadata={ + "help": "Type of containers used in slurm", + "choices": ["apptainer", "none"], + }, + ) + mount: str = field( + default="/storage:/storage", metadata={"help": "Mount path for slurm."} + ) + trainer_image: str = field( + default="", metadata={"help": "slurm image for trainers."} + ) + inference_server_image: str = field( + default="", metadata={"help": "slurm image for LLM inference."} + ) + + +@dataclass +class LauncherConfig: + """Configuration for launching the SGLang server.""" + + inference_server_cpus_per_gpu: int = field( + default=4, + metadata={"help": "Number of CPUs allocated per GPU for inference server. "}, + ) + inference_server_mem_per_gpu: int = field( + default=32 * 1024, + metadata={"help": "Memory allocated per GPU for inference server in MB. "}, + ) + trainer_cpus_per_gpu: int = field( + default=4, + metadata={"help": "Number of CPUs allocated per GPU for training. "}, + ) + trainer_mem_per_gpu: int = field( + default=32 * 1024, + metadata={"help": "Memory allocated per GPU for training in MB. "}, + ) + inference_server_env_vars: str = field( + default="", + metadata={ + "help": "Environment variables for inference server, seperated by commas. " + "Example: 'ENV1=val1,ENV2=val2'. " + }, + ) + trainer_env_vars: str = field( + default="", + metadata={ + "help": "Environment variables for training, seperated by commas. " + "Example: 'ENV1=val1,ENV2=val2'. " + }, + ) + slurm: SlurmLauncherConfig = field( + default_factory=SlurmLauncherConfig, + metadata={"help": "Slurm launcher configuration."}, + ) + + @dataclass class BaseExperimentConfig: # NOTE: we need this unified config class because different experiments @@ -742,12 +794,6 @@ class BaseExperimentConfig: default_factory=ClusterSpecConfig, metadata={"help": "Cluster specification. 
Mainly used by slurm."}, ) - n_nodes: int = field( - default=1, metadata={"help": "Number of nodes for experiment."} - ) - n_gpus_per_node: int = field( - default=8, metadata={"help": "Number of GPUs per node for this experiment."} - ) allocation_mode: str = field( default="", metadata={ @@ -785,6 +831,7 @@ class BaseExperimentConfig: server_only: bool = False sglang: SGLangConfig = field(default_factory=SGLangConfig) + launcher: LauncherConfig = field(default_factory=LauncherConfig) @dataclass diff --git a/arealite/engine/base_hf_engine.py b/arealite/engine/base_hf_engine.py index 1adc8d4..8d0b1a8 100644 --- a/arealite/engine/base_hf_engine.py +++ b/arealite/engine/base_hf_engine.py @@ -266,6 +266,7 @@ class BaseHFEngine(TrainEngine): # Scale loss for accumulation # Revert gradient averaging across dp ranks + # FIXME: should be DP size loss_scale *= self.world_size loss *= loss_scale @@ -286,8 +287,6 @@ class BaseHFEngine(TrainEngine): update_successful = True current_lr = self.lr_scheduler.get_last_lr()[0] - # Optimizer step - self.optimizer.step() return dict( update_successful=float(update_successful), grad_norm=float(grad_norm) if grad_norm is not None else float("nan"), diff --git a/arealite/engine/fsdp_engine.py b/arealite/engine/fsdp_engine.py index d785fda..a0926c3 100644 --- a/arealite/engine/fsdp_engine.py +++ b/arealite/engine/fsdp_engine.py @@ -1,10 +1,7 @@ -import dis -import gc import os -import threading import time from datetime import datetime -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Callable, Dict, Optional, Tuple import torch import torch.distributed as dist @@ -14,14 +11,7 @@ from torch.distributed.checkpoint.state_dict import ( StateDictOptions, get_model_state_dict, ) -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from transformers import ( - AutoConfig, - AutoModelForCausalLM, - PreTrainedTokenizerFast, - get_constant_schedule_with_warmup, - get_linear_schedule_with_warmup, -) +from transformers import PreTrainedTokenizerFast from arealite.api.cli_args import TrainEngineConfig from arealite.api.engine_api import FinetuneSpec, SaveLoadMeta, WeightUpdateMeta @@ -232,6 +222,7 @@ class FSDPEngine(BaseHFEngine): for i, (pad_length, padded_mb_input, mb_input) in enumerate( zip(mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs) ): + self.model.set_requires_gradient_sync(i == len(mb_list.mbs) - 1) outputs = self.model(**padded_mb_input) logits = outputs.logits.squeeze(0) @@ -258,8 +249,6 @@ class FSDPEngine(BaseHFEngine): update_successful = True current_lr = self.lr_scheduler.get_last_lr()[0] - # Optimizer step - self.optimizer.step() return dict( update_successful=float(update_successful), grad_norm=float(grad_norm) if grad_norm is not None else float("nan"), diff --git a/arealite/engine/sft/lm_engine.py b/arealite/engine/sft/lm_engine.py index 18bee78..eaa8fc3 100644 --- a/arealite/engine/sft/lm_engine.py +++ b/arealite/engine/sft/lm_engine.py @@ -58,15 +58,9 @@ def compute_packed_sft_loss(logits: torch.Tensor, input_: TensorDict) -> torch.T cu_seqlens.shape[0] - 1, device=logits.device, dtype=torch.float64 ) for i in range(cu_seqlens.shape[0] - 1): - m = loss_mask[cu_seqlens[i] - i : cu_seqlens[i + 1] - i - 1] - logp = logprobs[cu_seqlens[i] - i : cu_seqlens[i + 1] - i - 1] - assert cu_seqlens[i + 1] - i - 1 <= logprobs.shape[0], ( - cu_seqlens, - logprobs.shape, - ) - seqlogp[i] = torch.where(m, logp.detach(), 0.0).sum() / ( - m.numel() - m.count_nonzero() - ) + m = loss_mask[cu_seqlens[i] : 
cu_seqlens[i + 1]] + logp = logprobs[cu_seqlens[i] : cu_seqlens[i + 1]] + seqlogp[i] = torch.where(m, logp.detach(), 0.0).sum() / (m.count_nonzero()) ## Loggin stats stats_tracker.denominator( diff --git a/arealite/launcher/ray.py b/arealite/launcher/ray.py new file mode 100644 index 0000000..0344e98 --- /dev/null +++ b/arealite/launcher/ray.py @@ -0,0 +1,410 @@ +import getpass +import importlib.util +import os +import pathlib +import sys +import time +from typing import Dict, List, Optional + +import ray +import ray.exceptions +from ray.runtime_env import RuntimeEnv +from ray.util.placement_group import PlacementGroup +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy + +import realhf.base.logging as logging +from arealite.api.cli_args import ( + ClusterSpecConfig, + LauncherConfig, + SGLangConfig, + parse_cli_args, + to_structured_cfg, +) +from arealite.api.io_struct import AllocationMode, AllocationType +from arealite.utils.launcher import ( + get_env_vars, + validate_config_for_distributed_launcher, + wait_sglang_server_addrs, +) +from arealite.utils.ray import get_placement_group_master_ip_and_port +from realhf.base import logging, name_resolve, names +from realhf.scheduler.client import JobException, JobState + +logger = logging.getLogger("RayLauncher") + +RAY_WAIT_CHECK_TIME_INTERVAL = 5 # seconds +DEFAULT_MAIN_FUNC_NAME = "main" + + +def run_func(file_path, function_name, *args, **kwargs): + # Convert the file path to a module name + module_name = file_path.replace("/", "_").replace(".", "_") + + # Load the module from file path + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + + # Get the function and execute it + try: + function = getattr(module, function_name) + except AttributeError as e: + raise ValueError( + f"Function '{function_name}' not found in module '{module_name}'. " + f"Please ensure the name of the main function in your entry point " + f"is '{function_name}'." 
+ ) from e + return function(*args, **kwargs) + + +class RayLauncher: + def __init__(self, experiment_name: str, trial_name: str, fileroot: str): + self.experiment_name = experiment_name + self.trial_name = trial_name + self.fileroot = fileroot + + # job_name to ray future + self.jobs = {} + + @property + def run_name(self): + return f"{self.experiment_name}_{self.trial_name}" + + def log_path_of(self, job_name: str) -> str: + log_path = f"{self.fileroot}/logs/{getpass.getuser()}/{self.experiment_name}/{self.trial_name}" + os.makedirs(log_path, exist_ok=True) + return os.path.join(log_path, f"{job_name}.log") + + def submit( + self, + job_name: str, + file_path: str, + func_name: str, + args: List[str], # arguments to pass to the function + gpus: int, + cpus: int, + mem: int, # MB + env_vars: Optional[Dict] = None, + placement_group: Optional[PlacementGroup] = None, + bundle_index: int = -1, + kwargs: Optional[ + Dict[str, str] + ] = None, # keyword arguments to pass to the function + ): + if kwargs is None: + kwargs = {} + runtime_env = RuntimeEnv( + env_vars=env_vars or dict(), + ) + scheduling_strategy = ( + PlacementGroupSchedulingStrategy( + placement_group=placement_group, + placement_group_bundle_index=bundle_index, + placement_group_capture_child_tasks=True, + ) + if placement_group is not None + else "DEFAULT" + ) + future = ray.remote( + num_cpus=cpus, + num_gpus=gpus, + memory=mem * 1024 * 1024, # Convert MB to bytes + runtime_env=runtime_env, + scheduling_strategy=scheduling_strategy, + )(run_func).remote(file_path, func_name, *args, **kwargs) + self.jobs[job_name] = future + return future + + def submit_array( + self, + job_name: str, + file_path: str, + func_name: str, + count: int, + nodes: int, + list_args: List[List], + gpus_per_task: int, + cpus_per_task: int, + mem_per_task: int, # MB + list_kwargs: List[Dict] | None = None, + env_vars: Optional[Dict] = None, + amend_torch_dist_env: bool = False, + ): + """Submit an array of jobs to Ray with ray placement groups. + + Note: Here we use `ray.remote` instead of `ray job submit` since `ray job submit` + does not support placement groups, and can not specify which node to run the job on. + Therefore we could not know the IP address of jobs for torch distributed initialization. + """ + + if count % nodes != 0: + raise ValueError( + f"Count {count} is not divisible by nodes {nodes}. " + "Please ensure that count is a multiple of nodes." + ) + assert ( + len(list_args) == count + ), f"Length of list_args {len(list_args)} does not match count {count}." + if list_kwargs is not None: + assert ( + len(list_kwargs) == count + ), f"Length of list_kwargs {len(list_kwargs)} does not match count {count}." + + tasks_per_node = count // nodes + gpus_per_node = gpus_per_task * tasks_per_node + cpus_per_node = cpus_per_task * tasks_per_node + mem_per_node = mem_per_task * tasks_per_node + + placement_group = ray.util.placement_group( + bundles=[ + { + "CPU": cpus_per_node, + "GPU": gpus_per_node, + "memory": mem_per_node * 1024 * 1024, # Convert MB to bytes + } + ] + * nodes, + strategy="STRICT_SPREAD", + ) + try: + ray.get(placement_group.ready(), timeout=30) + except ray.exceptions.GetTimeoutError as e: + logger.error( + "Ray placement group timeout, please check if the resource requirement " + "for your experiment exceeds the available resources in the cluster. 
\n" + f"ray.nodes(): {ray.nodes()} \n" + f"Placement Group bundles: " + f"cpus_per_node={cpus_per_node}, gpus_per_node={gpus_per_node}, " + f"mem_per_node={mem_per_node}MB, nodes={nodes}" + ) + raise e + + if amend_torch_dist_env: + host_ip, port = get_placement_group_master_ip_and_port(placement_group) + logger.info( + f"Amend torch distributed env vars: " + f"MASTER_ADDR={host_ip}, PORT={port}" + ) + + futures = [] + for i in range(count): + args = list_args[i] + kwargs = list_kwargs[i] if list_kwargs is not None else {} + + # manage environment variables + env_vars = env_vars or {} + if "CUDA_VISIBLE_DEVICES" in env_vars: + logger.warning( + "Setting CUDA_VISIBLE_DEVICES before running ray jobs may result in unexpected behavior." + ) + + node_id = i // tasks_per_node + _env_vars = { + **env_vars, + } + + if amend_torch_dist_env: + assert gpus_per_task == 1 + # NOTE: Here we only provide environment variables for torch distributed + # initialization, and LOCAL_RANK for torch.device. + # Other environment variables automatically set by torchrun are not set, and + # they should be never accessed in trainer code. + _env_vars.update( + { + "RANK": str(i), + "WORLD_SIZE": str(count), + # Ray will automatically isolate CUDA_VISIBLE_DEVICES for each GPU + "LOCAL_RANK": "0", + "MASTER_ADDR": str(host_ip), + "MASTER_PORT": str(port), + } + ) + future = self.submit( + job_name=f"{job_name}:{i}", + file_path=file_path, + func_name=func_name, + args=args, + gpus=gpus_per_task, + cpus=cpus_per_task, + mem=mem_per_task, + env_vars=_env_vars, + placement_group=placement_group, + bundle_index=node_id, + kwargs=kwargs, + ) + futures.append(future) + + return futures + + def stop(self, job_name: str, force: bool = False): + """Stop a job by name.""" + if job_name in self.jobs: + future = self.jobs[job_name] + try: + ray.cancel(future, force=force) + except Exception as e: + logger.error(f"Failed to cancel job {job_name}: {e}") + return + self.jobs.pop(job_name, None) + logger.info(f"Job {job_name} stopped.") + else: + logger.warning(f"Job {job_name} not found in running jobs.") + + def stop_all(self, force: bool = False): + """Stop all jobs.""" + for job_name in list(self.jobs.keys()): + self.stop(job_name, force=force) + logger.info("All jobs stopped.") + self.jobs.clear() + + def wait( + self, check_status=(JobState.FAILED,), remove_status=(JobState.COMPLETED,) + ): + """Check every RAY_WAIT_CHECK_TIME_INTERVAL seconds for the status of all jobs. + If a ray job returns, its status changes to JobState.COMPLETED. + If a ray job failed, its status changes to JobState.FAILED. + If any job is in check_status, stop all jobs at once. + If any job is in remove status, remove them from job list. + Return if all jobs are removed from job list, or some job is in check status. + """ + for status in list(check_status) + list(remove_status): + assert status in [ + JobState.COMPLETED, + JobState.FAILED, + ], "In RayLauncher.wait, we only check completed or failed jobs." 
+ logger.info(f"Waiting for {len(self.jobs)} jobs.") + while self.jobs: + job_status = {} + for job_name, future in list(self.jobs.items()): + try: + r = ray.get(future, timeout=0.1) + logger.info(f"Job {job_name} completed with result: {r}") + job_status[job_name] = JobState.COMPLETED + except ray.exceptions.RayTaskError as e: + logger.error(f"Job {job_name} failed with error: {e}.") + job_status[job_name] = JobState.FAILED + except ray.exceptions.GetTimeoutError: + continue + + for job_name, status in job_status.items(): + if status in check_status: + logger.info(f"Job {job_name} is {status}, stopping all jobs.") + self.stop_all(force=True) + return + if status in remove_status: + logger.info(f"Job {job_name} is {status}, removed.") + self.jobs.pop(job_name) + + time.sleep(RAY_WAIT_CHECK_TIME_INTERVAL) + + +def ray_main(): + # usage: python -m arealite.launcher.ray --config [] + ray.init() + config, config_file = parse_cli_args(sys.argv[2:]) + config.launcher = to_structured_cfg(config.launcher, LauncherConfig) + config.cluster = to_structured_cfg(config.cluster, ClusterSpecConfig) + config.sglang = to_structured_cfg(config.sglang, SGLangConfig) + validate_config_for_distributed_launcher(config) + + name_resolve.reconfigure(config.cluster.name_resolve) + name_resolve.clear_subtree( + names.trial_root( + experiment_name=config.experiment_name, trial_name=config.trial_name + ) + ) + + n_nodes = config.cluster.n_nodes + n_gpus_per_node = config.cluster.n_gpus_per_node + launcher = RayLauncher( + experiment_name=config.experiment_name, + trial_name=config.trial_name, + fileroot=config.cluster.fileroot, + ) + allocation_mode = config.allocation_mode + allocation_mode = AllocationMode.from_str(allocation_mode) + sglang_cmds = [] + sglang_addrs = [] + n_sglang_nodes = 0 + if allocation_mode.type_ == AllocationType.DECOUPLED_SGLANG: + # Launcher should launch SGLang servers according to allocation mode. 
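+        # One SGLang server is launched per generation DP group; each server spans
+        # `gen_tp_size` GPUs, and the fleet as a whole occupies
+        # `gen_world_size // n_gpus_per_node` nodes.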
+ sglang_tp_size = allocation_mode.gen_tp_size + n_sglang_servers = allocation_mode.gen_dp_size + n_sglang_nodes = allocation_mode.gen_world_size // n_gpus_per_node + + base_seed = config.sglang.random_seed + sglang_args_list = [ + [sys.argv[2:] + [f"sglang.random_seed={base_seed + i}"]] + for i in range(n_sglang_servers) + ] + sglang_entry_point = str( + pathlib.Path(__file__).resolve().parent.joinpath("sglang_server.py") + ) + launcher.submit_array( + job_name="llm_server", + file_path=sglang_entry_point, + func_name=DEFAULT_MAIN_FUNC_NAME, + count=n_sglang_servers, + nodes=n_sglang_nodes, + list_args=sglang_args_list, + gpus_per_task=sglang_tp_size, + cpus_per_task=config.launcher.inference_server_cpus_per_gpu + * sglang_tp_size, + mem_per_task=config.launcher.inference_server_mem_per_gpu * sglang_tp_size, + env_vars=get_env_vars( + config.cluster.cluster_name, + config.launcher.inference_server_env_vars, + ), + ) + # Get SGLang server addresses via name_resolve + try: + sglang_addrs = wait_sglang_server_addrs( + config.experiment_name, + config.trial_name, + n_sglang_servers, + ) + except TimeoutError as e: + launcher.stop_all(force=True) + raise e + + trainer_n_nodes = n_nodes - n_sglang_nodes + trainer_entry_point = sys.argv[1] + n_trainer_processes = trainer_n_nodes * config.cluster.n_gpus_per_node + trainer_args_list = [[sys.argv[2:]] for _ in range(n_trainer_processes)] + if not config.server_only: + # In ray, we launch trainer in the granularity of processes (1 GPU per process) + # We amend environment variable similar to torchrun to ensure correct initialization of + # torch distributed. + launcher.submit_array( + job_name="trainer", + file_path=trainer_entry_point, + func_name=DEFAULT_MAIN_FUNC_NAME, + count=trainer_n_nodes * config.cluster.n_gpus_per_node, + nodes=trainer_n_nodes, + list_args=trainer_args_list, + gpus_per_task=1, + cpus_per_task=config.launcher.trainer_cpus_per_gpu, + mem_per_task=config.launcher.trainer_mem_per_gpu, + env_vars=dict( + **get_env_vars( + config.cluster.cluster_name, + config.launcher.trainer_env_vars, + ), + AREAL_LLM_SERVER_ADDRS=",".join(sglang_addrs), + ), + amend_torch_dist_env=True, + ) + + try: + launcher.wait(check_status=(JobState.COMPLETED, JobState.FAILED)) + except (KeyboardInterrupt, JobException, TimeoutError) as e: + launcher.stop_all(force=True) + raise e + + +if __name__ == "__main__": + # usage: python -m arealite.launcher.ray \ + # --config [] \ + # launcher.ray.main_func_name= + ray_main() diff --git a/arealite/launcher/sglang_server.py b/arealite/launcher/sglang_server.py new file mode 100644 index 0000000..2e88c04 --- /dev/null +++ b/arealite/launcher/sglang_server.py @@ -0,0 +1,197 @@ +import os +import subprocess +import sys +import time +import uuid +from pathlib import Path +from typing import Optional + +import ray +import requests + +from arealite.api.cli_args import ( + ClusterSpecConfig, + NameResolveConfig, + SGLangConfig, + parse_cli_args, + to_structured_cfg, +) +from arealite.api.io_struct import AllocationMode, AllocationType +from arealite.utils.launcher import TRITON_CACHE_PATH +from arealite.utils.network import find_free_ports, gethostip +from realhf.base import logging, name_resolve, names, pkg_version + +logger = logging.getLogger("SGLangServer Wrapper") + + +def execute_shell_command(command: str) -> subprocess.Popen: + """ + Execute a shell command and return its process handle. + """ + # Replace newline continuations and split the command string. 
+ command = command.replace("\\\n", " ").replace("\\", " ") + parts = command.split() + _env = os.environ.copy() + # To avoid DirectoryNotEmpty error caused by triton + triton_cache_path = _env.get("TRITON_CACHE_PATH", TRITON_CACHE_PATH) + unique_triton_cache_path = os.path.join(triton_cache_path, str(uuid.uuid4())) + _env["TRITON_CACHE_PATH"] = unique_triton_cache_path + return subprocess.Popen( + parts, + text=True, + env=_env, + stdout=sys.stdout, + stderr=subprocess.STDOUT, + ) + + +def apply_sglang_patch(): + p = Path(os.path.dirname(__file__)) + patch_path = str( + p.parent.parent + / "patch" + / "sglang" + / f"v{pkg_version.get_version('sglang')}.patch" + ) + + target_path = "" + sglang_meta = subprocess.check_output( + "python3 -m pip show sglang", shell=True + ).decode("ascii") + for line in sglang_meta.split("\n"): + line = line.strip() + if line.startswith("Editable project location: "): + target_path = str(Path(line.split(": ")[1]).parent) + + if target_path: + proc = subprocess.Popen( + ["git", "apply", patch_path], + cwd=target_path, + stderr=sys.stdout, + stdout=sys.stdout, + ) + proc.wait() + logger.info(f"Applied SGLang patch at {target_path}") + + +def launch_server_cmd(command: str): + """ + Launch the server using the given command. + If no port is specified, a free port is reserved. + """ + if not ray.is_initialized(): + apply_sglang_patch() + process = execute_shell_command(command) + return process + + +def wait_for_server(base_url: str, timeout: Optional[int] = None) -> None: + """Wait for the server to be ready by polling the /v1/models endpoint. + + Args: + base_url: The base URL of the server + timeout: Maximum time to wait in seconds. None means wait forever. + """ + start_time = time.time() + while True: + try: + response = requests.get( + f"{base_url}/v1/models", + headers={"Authorization": "Bearer None"}, + ) + if response.status_code == 200: + time.sleep(5) + break + + if timeout and time.time() - start_time > timeout: + raise TimeoutError("Server did not become ready within timeout period") + except requests.exceptions.RequestException: + time.sleep(1) + + +class SGLangServerWrapper: + def __init__( + self, + experiment_name: str, + trial_name: str, + sglang_config: SGLangConfig, + tp_size: int, + n_gpus_per_node: int, + ): + self.experiment_name = experiment_name + self.trial_name = trial_name + self.config = sglang_config + self.tp_size = tp_size + self.server_process = None + self.n_gpus_per_node = n_gpus_per_node + + def run(self): + gpus_per_server = len(os.getenv("CUDA_VISIBLE_DEVICES").split(",")) + server_local_idx = ( + int(os.getenv("CUDA_VISIBLE_DEVICES").split(",")[0]) // gpus_per_server + ) + n_servers_per_node = max(1, self.n_gpus_per_node // gpus_per_server) + ports_per_server = 40000 // n_servers_per_node + port_range = ( + server_local_idx * ports_per_server + 10000, + (server_local_idx + 1) * ports_per_server + 10000, + ) + server_port, dist_init_port = find_free_ports(2, port_range) + + dist_init_addr = f"localhost:{dist_init_port}" + host_ip = gethostip() + + cmd = SGLangConfig.build_cmd( + self.config, + tp_size=self.tp_size, + base_gpu_id=0, + host=host_ip, + port=server_port, + dist_init_addr=dist_init_addr, + ) + self.server_process = launch_server_cmd(cmd) + wait_for_server(f"http://{host_ip}:{server_port}") + + name = names.gen_servers(self.experiment_name, self.trial_name) + name_resolve.add_subentry(name, f"{host_ip}:{server_port}") + + logger.info(f"SGLang server launched at: http://{host_ip}:{server_port}") + return_code = 
self.server_process.wait() + logger.info( + f"SGLang server at http://{host_ip}:{server_port} exits, returncode={return_code}" + ) + + def __del__(self): + if self.server_process and self.server_process.poll() is None: + logger.info("Terminating SGLang server process...") + self.server_process.terminate() + self.server_process.wait() + logger.info("SGLang server process terminated.") + + +def main(argv): + config, _ = parse_cli_args(argv) + config.sglang = to_structured_cfg(config.sglang, SGLangConfig) + config.cluster = to_structured_cfg(config.cluster, ClusterSpecConfig) + config.cluster.name_resolve = to_structured_cfg( + config.cluster.name_resolve, NameResolveConfig + ) + name_resolve.reconfigure(config.cluster.name_resolve) + + allocation_mode = config.allocation_mode + allocation_mode = AllocationMode.from_str(allocation_mode) + assert allocation_mode.type_ == AllocationType.DECOUPLED_SGLANG + tp_size = allocation_mode.gen_tp_size + + sglang_server = SGLangServerWrapper( + config.experiment_name, + config.trial_name, + config.sglang, + tp_size, + n_gpus_per_node=config.cluster.n_gpus_per_node, + ) + sglang_server.run() + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/arealite/launcher/slurm.py b/arealite/launcher/slurm.py new file mode 100644 index 0000000..ca91112 --- /dev/null +++ b/arealite/launcher/slurm.py @@ -0,0 +1,528 @@ +import getpass +import os +import re +import subprocess +import sys +import time +from typing import Dict, List, Optional, Tuple + +import realhf.base.logging as logging +from arealite.api.cli_args import ( + ClusterSpecConfig, + LauncherConfig, + SGLangConfig, + parse_cli_args, + to_structured_cfg, +) +from arealite.api.io_struct import AllocationMode, AllocationType +from arealite.utils.launcher import ( + get_env_vars, + validate_config_for_distributed_launcher, + wait_sglang_server_addrs, +) +from arealite.utils.slurm import ( + APPTAINER_CMD_TEMPLATE, + SBATCH_SCRIPT_TEMPLATE, + SRUN_CMD_TEMPLATE, + cancel_jobs, + query_jobs, +) +from realhf.base import logging, name_resolve, names +from realhf.scheduler.client import JobException, JobInfo, JobState + +logger = logging.getLogger("SlurmLauncher") + +SLURM_WAIT_CHECK_TIME_INTERVAL = 5 + + +class SlurmLauncher: + def __init__( + self, experiment_name: str, trial_name: str, fileroot: str, container_type: str + ): + self.experiment_name = experiment_name + self.trial_name = trial_name + self.fileroot = fileroot + self.container_type = container_type + + # slurm_job_id -> JobInfo + self.jobs: Dict[int, JobInfo] = {} + self.job_names = [] + + @property + def run_name(self) -> str: + """Returns the run name of this launcher.""" + return f"{self.experiment_name}_{self.trial_name}" + + def slurm_name(self, job_name: str) -> str: + """Returns the slurm name of a job.""" + return f"{self.experiment_name}_{self.trial_name}:{job_name}" + + def log_path_of(self, job_name: str) -> str: + log_path = f"{self.fileroot}/logs/{getpass.getuser()}/{self.experiment_name}/{self.trial_name}" + os.makedirs(log_path, exist_ok=True) + return os.path.join(log_path, f"{job_name}.log") + + def sbatch_path_of(self, job_name: str) -> str: + sbatch_path = f"{self.fileroot}/logs/{getpass.getuser()}/{self.experiment_name}/{self.trial_name}" + os.makedirs(sbatch_path, exist_ok=True) + return os.path.join(sbatch_path, f"{job_name}.sh") + + def submit(self, job_name, cmd, **kwargs): + """Submits and launch a job with SBATCH. + + Args: + cmd (str or List[str]): The core command to be executed. 
+ """ + return self.submit_array(job_name, cmd, count=1, **kwargs) + + def find_job_id(self, job_name: str): + job_name = self.slurm_name(job_name) + for job_id, job_info in self.jobs.items(): + if job_info.name == job_name: + return job_id + return None + + def submit_array( + self, + job_name: str, + cmd: List[str] | str, + count: int, + nodes: int, + n_gpus_per_node: int, + cpus_per_task: int, + mem_per_task: int, # MB + container_image: str, + srun_additional_args: str = "", + container_mounts: Optional[str] = None, + env_vars: Optional[Dict] = None, + nodelist: Optional[str] = None, + exclude: Optional[str] = None, + ): + """Submits and launch a job array with SBATCH. + Note that a job array has one (unique) slurm name, and one (unique) slurm id. + + Args: + job_name (str): The job name of the job array. The actual slurm name will be + `_:`. + cmd (str or List[str]): The core command to be executed. + count (int): The number of jobs in the array. + """ + assert job_name not in self.job_names, ( + f"Job {job_name} is already submitted. " + "Please use a different job name or stop the existing job." + ) + if isinstance(cmd, str): + cmd = [cmd] + assert len(cmd) == count, ( + f"Command length {len(cmd)} does not match the job count {count}. " + "Please provide a command for each job in the array." + ) + assert count % nodes == 0, ( + f"Job count {count} must be divisible by the number of nodes {nodes}. " + "Please adjust the job count or the number of nodes." + ) + ntasks_per_node = count // nodes + assert n_gpus_per_node % ntasks_per_node == 0, ( + "GPUs must be evenly distributed across tasks. " + f"Current #GPUs per node {n_gpus_per_node}, #tasks per node {ntasks_per_node}." + ) + + mem_per_cpu = mem_per_task // cpus_per_task # MB per CPU + mem_per_node = ( + mem_per_task * count // nodes + 1024 * 10 + ) # make sure slurm does not run out of resources + + sbatch_options = [ + f"--job-name={self.slurm_name(job_name)}", + f"--output={self.log_path_of(job_name)}", + "--open-mode=append", + "--no-requeue", + f"--nodes={nodes}-{nodes}", + f"--ntasks-per-node={ntasks_per_node}", + f"--gres=gpu:{n_gpus_per_node}", + f"--cpus-per-task={cpus_per_task}", + f"--mem={mem_per_node}M", + ] + + if nodelist: + sbatch_options.append(f"--nodelist={nodelist}") + if exclude: + sbatch_options.append(f"--exclude={exclude}") + + sbatch_options_str = "\n".join([f"#SBATCH {opt}" for opt in sbatch_options]) + + if env_vars is None: + env_vars = dict() + n_gpus_per_task = n_gpus_per_node // ntasks_per_node + assert ( + "CUDA_VISIBLE_DEVICES" not in env_vars + ), "CUDA_VISIBLE_DEVICES should be automatically resolved by Launcher instead of manually assigned." 
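+        # Build one srun command per task. Each task receives its own
+        # CUDA_VISIBLE_DEVICES slice so that tasks sharing a node never
+        # overlap on GPUs.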
+ + srun_cmds = [] + for i in range(count): + # resolve CUDA_VISIBLE_DEVICES for each task + gpu_id_start = (i % ntasks_per_node) * n_gpus_per_task + gpu_id_end = ((i % ntasks_per_node) + 1) * n_gpus_per_task + node_id = i // ntasks_per_node + _env_vars = { + **env_vars, + "CUDA_VISIBLE_DEVICES": ",".join( + str(x) for x in range(gpu_id_start, gpu_id_end) + ), + } + # Prepare the command for each job in the array + job_cmd = cmd[i] + + if self.container_type == "apptainer": + env_string = " ".join( + "--env {}={}".format(k, v) for k, v in _env_vars.items() + ) + apptainer_cmd = APPTAINER_CMD_TEMPLATE.format( + container_mounts=container_mounts or "", + container_env_strings=env_string, + container_image=container_image, + cmd=job_cmd, + ) + srun_cmd = SRUN_CMD_TEMPLATE.format( + additional_args=srun_additional_args, + nodes=1, + ntasks=1, + node_id=node_id, + n_gpus_per_node=n_gpus_per_task, + cpus_per_task=cpus_per_task, + mem_per_cpu=mem_per_cpu, + cmd=apptainer_cmd, + ) + elif self.container_type == "none": + env_string = "--export=" + ",".join( + "{}={}".format(k, v) for k, v in _env_vars.items() + ) + srun_additional_args = srun_additional_args + " " + env_string + srun_cmd = SRUN_CMD_TEMPLATE.format( + additional_args=srun_additional_args, + nodes=1, + ntasks=1, + node_id=node_id, + n_gpus_per_node=n_gpus_per_task, + cpus_per_task=cpus_per_task, + mem_per_cpu=mem_per_cpu, + cmd=job_cmd, + ) + else: + raise ValueError( + f"Unsupported container type: {self.container_type}. " + "Supported types are 'apptainer' and 'none'." + ) + srun_cmds.append(srun_cmd) + + srun_cmds = "\n".join(srun_cmds) + sbatch_script = SBATCH_SCRIPT_TEMPLATE.format( + sbatch_options=sbatch_options_str, + srun_additional_args=srun_additional_args, + srun_cmds=srun_cmds, + ) + sbatch_file_path = self.sbatch_path_of(f"{job_name}") + with open(sbatch_file_path, "w") as f: + f.write(sbatch_script) + + # Submit the job + try: + output = ( + subprocess.check_output(["sbatch", sbatch_file_path]) + .decode("utf-8") + .strip() + ) + logger.info( + f"Submitted Slurm job {self.slurm_name(job_name)} to scheduler. To check the output, run \n\t`tail -f {self.log_path_of(job_name)}`." + ) + except subprocess.CalledProcessError as e: + logger.warning( + f"Failed to submit job {self.slurm_name(job_name)}. " + f"For debugging, please make sure your sbatch command works " + f"and check generated sbatch file on {sbatch_file_path}." + ) + logger.error(f"Error message: {e}") + return + + match = re.search(r"Submitted batch job (\d+)", output) + slurm_job_id = int(match.group(1)) if match else None + if slurm_job_id is None: + logger.warning( + f"Failed to obtain job id for job {self.slurm_name(job_name)}. " + f"sbatch output: {output}" + ) + return + + assert isinstance(slurm_job_id, int) + self.jobs[slurm_job_id] = JobInfo( + name=self.slurm_name(job_name), + state=JobState.PENDING, + slurm_id=slurm_job_id, + ) + self._update_all() + + def stop(self, job_name, force=False): + """Stops a running job. + + Raises exception if there is no such job, but passes if the job + has stopped either successfully or not. + + Args: + job_name: The job name of the job array to stop. + The actual slurm job name will be `_:`. 
+ """ + signal = "SIGKILL" if force else "SIGTERM" + job_id = self.find_job_id(job_name) + if not job_id: + return + return cancel_jobs(slurm_ids=[job_id], signal=signal) + + def stop_all(self, force=False): + """Stops all running jobs.""" + signal = "SIGKILL" if force else "SIGTERM" + return cancel_jobs(slurm_ids=list(self.jobs.keys()), signal=signal) + + def find(self, job_name) -> JobInfo | None: + """Gets the status of a job of this job. + + Args: + job_name: The job name of the job array to find. + The actual slurm job name will be `_:`. + + Returns: + A JobInfo if the job is found, or None otherwise. + """ + self._update_all() + job_id = self.find_job_id(job_name) + return self.jobs[job_id] if job_id else None + + def find_all(self, job_name_regex=".*") -> List[JobInfo]: + """Finds jobs. + + Args: + job_name_regex: job name regex. + + Returns: + A list of found JobInfo. + """ + self._update_all() + infos = [] + for r in self.jobs.values(): + job_name = r.name.split(":")[-1] # Extract the job name from slurm name + if re.fullmatch(job_name_regex, job_name): + infos.append(r) + return infos + + def _find_job_with_status( + self, + status: List[JobState], + ) -> List[JobInfo]: + """Finds jobs with the given status. + + Args: + status: A list of JobState to filter jobs. + + Returns: + A list of JobInfo with the given status. + """ + self._update_all() + return [r for r in self.jobs.values() if r.state in status] + + def wait( + self, + timeout=None, + check_status: Tuple[JobState, ...] = ( + JobState.CANCELLED, + JobState.FAILED, + JobState.NOT_FOUND, + ), + remove_status: Tuple[JobState, ...] = (JobState.COMPLETED,), + update=False, + ): + """Waits until all jobs submitted via this client instance finish.""" + # begin wait + deadline = None if timeout is None else time.time() + timeout + + num_jobs_left = len(self.jobs) + left = list(self.jobs.keys()) + logger.info( + f"Waiting for {num_jobs_left} jobs. Jobs IDs: " + f"{','.join(sorted([str(x.slurm_id) for x in self.jobs.values()]))}." + ) + while len(left) > 0: + if len(left) < num_jobs_left: + num_jobs_left = len(left) + logger.info(f"Waiting for {num_jobs_left} jobs.") + if deadline is not None and time.time() > deadline: + raise TimeoutError( + f"Timeout waiting for {num_jobs_left} jobs. Job ID: " + f"{','.join(sorted([str(x.slurm_id) for x in self.jobs.values()]))}." + ) + self._update_all() + left = list(self.jobs.keys()) + for slurm_id in list(left): + slurm_info = self.jobs[slurm_id] + if slurm_info.slurm_id is None: + continue + if slurm_info.state in check_status: + raise JobException( + run_name=self.run_name, + worker_type=slurm_info.name, + host=slurm_info.host, + reason=slurm_info.state, + ) + if slurm_info.state in remove_status: + logger.info( + f"Job {slurm_info.name} is {slurm_info.state}. (Removed)" + ) + left.remove(slurm_id) + if update: + self.jobs.pop(slurm_info.slurm_id) + time.sleep(SLURM_WAIT_CHECK_TIME_INTERVAL) + + def _update_all(self): + """Updates the status of all jobs.""" + try: + slurm_infos = query_jobs(slurm_ids=list(self.jobs.keys())) + for slurm_info in slurm_infos: + assert slurm_info.slurm_id is not None + self.jobs[slurm_info.slurm_id] = slurm_info + except subprocess.CalledProcessError: + logger.warning( + "Calling squeue failed. Check slurm manually if you continue to see this warning." 
+ ) + + +def slurm_main(): + config, _ = parse_cli_args(sys.argv[2:]) + config.launcher = to_structured_cfg(config.launcher, LauncherConfig) + config.cluster = to_structured_cfg(config.cluster, ClusterSpecConfig) + config.sglang = to_structured_cfg(config.sglang, SGLangConfig) + validate_config_for_distributed_launcher(config) + + name_resolve.reconfigure(config.cluster.name_resolve) + name_resolve.clear_subtree( + names.trial_root( + experiment_name=config.experiment_name, trial_name=config.trial_name + ) + ) + + n_nodes = config.cluster.n_nodes + n_gpus_per_node = config.cluster.n_gpus_per_node + + launcher = SlurmLauncher( + experiment_name=config.experiment_name, + trial_name=config.trial_name, + fileroot=config.cluster.fileroot, + container_type=config.launcher.slurm.container_type, + ) + allocation_mode = config.allocation_mode + allocation_mode = AllocationMode.from_str(allocation_mode) + sglang_cmds = [] + sglang_addrs = [] + n_sglang_nodes = 0 + if allocation_mode.type_ == AllocationType.DECOUPLED_SGLANG: + # Launcher should launch SGLang servers according to allocation mode. + sglang_tp_size = allocation_mode.gen_tp_size + n_sglang_servers = allocation_mode.gen_dp_size + n_sglang_nodes = allocation_mode.gen_world_size // n_gpus_per_node + + base_seed = config.sglang.random_seed + sglang_server_cmd_template = f"python3 -m arealite.launcher.sglang_server {' '.join(sys.argv[2:])} sglang.random_seed={{seed}}" + for i in range(n_sglang_servers): + sglang_cmd = sglang_server_cmd_template.format( + seed=base_seed + i, + ) + sglang_cmds.append(sglang_cmd) + + launcher.submit_array( + job_name="llm_server", + cmd=sglang_cmds, + count=n_sglang_servers, + nodes=n_sglang_nodes, + n_gpus_per_node=config.cluster.n_gpus_per_node, + cpus_per_task=config.launcher.inference_server_cpus_per_gpu + * sglang_tp_size, + mem_per_task=config.launcher.inference_server_mem_per_gpu * sglang_tp_size, + srun_additional_args=config.launcher.slurm.srun_additional_args, + container_image=config.launcher.slurm.inference_server_image, + container_mounts=config.launcher.slurm.mount, + env_vars=get_env_vars( + config.cluster.cluster_name, + config.launcher.inference_server_env_vars, + ), + ) + # Get SGLang server addresses by name resolve + try: + sglang_addrs = wait_sglang_server_addrs( + config.experiment_name, + config.trial_name, + n_sglang_servers, + ) + except TimeoutError as e: + launcher.stop_all(force=True) + raise e + + trainer_n_nodes = n_nodes - n_sglang_nodes + # Here $head_node_ip is the IP address of the first node in the job array. + # $trainer_port is a free port on the head node. + # Both of them are obtained in by the SBATCH script. + trainer_cmd_template = ( + f"torchrun --nnodes={{nnodes}} --nproc-per-node={{nproc_per_node}} --node-rank {{node_rank}} " + f"--master-addr $head_node_ip --master-port $trainer_port {' '.join(sys.argv[1:])}" + ) + + trainer_cmds = [] + for i in range(trainer_n_nodes): + # In slurm, we launch trainer in the granularity of nodes with torchrun command. 
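+        # torchrun spawns one process per GPU on the node and performs rendezvous
+        # via $head_node_ip:$trainer_port, which the sbatch script resolves at runtime.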
+ trainer_cmds.append( + trainer_cmd_template.format( + nnodes=trainer_n_nodes, + nproc_per_node=config.cluster.n_gpus_per_node, + node_rank=i, + ) + ) + + if not config.server_only: + # launch trainers + launcher.submit_array( + job_name="trainer", + cmd=trainer_cmds, + count=trainer_n_nodes, + nodes=trainer_n_nodes, + n_gpus_per_node=config.cluster.n_gpus_per_node, + cpus_per_task=config.launcher.trainer_cpus_per_gpu + * config.cluster.n_gpus_per_node, + mem_per_task=config.launcher.trainer_mem_per_gpu + * config.cluster.n_gpus_per_node, + container_image=config.launcher.slurm.trainer_image, + srun_additional_args=config.launcher.slurm.srun_additional_args, + container_mounts=config.launcher.slurm.mount, + env_vars=dict( + **get_env_vars( + config.cluster.cluster_name, + config.launcher.trainer_env_vars, + ), + AREAL_LLM_SERVER_ADDRS=",".join(sglang_addrs), + ), + ) + + try: + launcher.wait( + check_status=( + JobState.CANCELLED, + JobState.FAILED, + JobState.NOT_FOUND, + JobState.COMPLETED, + ), + remove_status=(), + ) + except (KeyboardInterrupt, JobException, TimeoutError) as e: + launcher.stop_all(force=True) + raise e + + +if __name__ == "__main__": + # usage: python -m arealite.launcher.slurm \ + # --config [] + slurm_main() diff --git a/arealite/tests/test_fsdp_engine_nccl.py b/arealite/tests/test_fsdp_engine_nccl.py index a555a02..7d71fd6 100644 --- a/arealite/tests/test_fsdp_engine_nccl.py +++ b/arealite/tests/test_fsdp_engine_nccl.py @@ -2,11 +2,9 @@ import os import subprocess import sys import time -import uuid import pytest import requests -import torch from arealite.api.cli_args import ( InferenceEngineConfig, @@ -18,7 +16,6 @@ from arealite.api.io_struct import FinetuneSpec, WeightUpdateMeta from arealite.engine.fsdp_engine import FSDPEngine from arealite.engine.sglang_remote import RemoteSGLangEngine from arealite.utils.network import find_free_ports -from realhf.api.core.data_api import load_hf_tokenizer from realhf.base import network EXPR_NAME = "test_fsdp_engine_nccl" diff --git a/arealite/utils/data.py b/arealite/utils/data.py index b5f3174..d6a6677 100644 --- a/arealite/utils/data.py +++ b/arealite/utils/data.py @@ -399,9 +399,15 @@ def pad_packed_tensor_dict( padded_data[key] = new_max_seqlen elif torch.is_tensor(value) and value.numel() == total_length: # Pad the tensor to the new total length - padded_tensor = torch.nn.functional.pad( - value, (0, pad_length), value=pad_value - ) + if key == "position_ids": + # transformers will compute flash-attn arguments (e.g., cu_seqlens_q) + # according to this position ids. 
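+                # Padding position_ids with a fresh 0..pad_length-1 range makes the
+                # padded region look like one extra dummy sequence instead of
+                # extending the last real one.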
+ pad = torch.arange(pad_length, dtype=torch.long, device=value.device) + padded_tensor = torch.cat([value, pad]) + else: + padded_tensor = torch.nn.functional.pad( + value, (0, pad_length), value=pad_value + ) padded_data[key] = padded_tensor else: padded_data[key] = value diff --git a/arealite/utils/launcher.py b/arealite/utils/launcher.py new file mode 100644 index 0000000..3931cc0 --- /dev/null +++ b/arealite/utils/launcher.py @@ -0,0 +1,102 @@ +import getpass +import os +import pathlib +import time +from typing import Dict, Optional + +from arealite.api.io_struct import AllocationMode, AllocationType +from realhf.base import logging, name_resolve, names + +logger = logging.getLogger("Launcher Utils") + +LOCAL_CACHE_DIR = "/tmp/arealite" +PYTORCH_KERNEL_CACHE_PATH = ( + f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/torch/kernels/" +) +TRITON_CACHE_PATH = f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/triton/" +os.makedirs(PYTORCH_KERNEL_CACHE_PATH, exist_ok=True) +os.makedirs(TRITON_CACHE_PATH, exist_ok=True) +BASE_ENVIRONS = { + "TOKENIZERS_PARALLELISM": "true", + "PYTORCH_KERNEL_CACHE_PATH": PYTORCH_KERNEL_CACHE_PATH, + "TRITON_CACHE_DIR": TRITON_CACHE_PATH, + "CUDA_DEVICE_MAX_CONNECTIONS": "1", + "PYTHONPATH": str(pathlib.Path(__file__).resolve().parent.parent.parent), +} +NA132_ENVIRONS = { + "NCCL_SOCKET_IFNAME": "bond0", + "NCCL_NET_PLUGIN": "", + "NCCL_IB_GID_INDEX": "3", + "NCCL_IB_TIMEOUT": "2", + "NCCL_IB_RETRY_CNT": "7", + "NCCL_IB_SL": "5", + "NCCL_IB_TC": "136", + "NCCL_IB_HCA": "mlx5_bond", + "NCCL_IB_QPS_PER_CONNECTION": "8", + "NCCL_SET_THREAD_NAME": "1", + "NCCL_DEBUG": "WARN", + "NCCL_DEBUG_SUBSYS": "INIT,TUNING,GRAPH", +} +SGLANG_SERVER_WAIT_TIMEOUT_SECONDS = 180 + + +def get_env_vars( + cluster_name: str, additional_env_vars: Optional[str] = None +) -> Dict[str, str]: + """Returns the environment variables for the cluster.""" + _additional_env_vars = ( + dict(item.split("=") for item in additional_env_vars.split(",")) + if additional_env_vars + else dict() + ) + if cluster_name == "na132": + return {**BASE_ENVIRONS, **NA132_ENVIRONS, **_additional_env_vars} + else: + return {**BASE_ENVIRONS, **_additional_env_vars} + + +def wait_sglang_server_addrs( + experiment_name: str, + trial_name: str, + n_sglang_servers: int, +): + # Get SGLang slurm nodes, find the hosts + name = names.gen_servers(experiment_name, trial_name) + start = time.perf_counter() + while True: + sglang_addrs = name_resolve.get_subtree(name) + if len(sglang_addrs) >= n_sglang_servers: + logger.info( + f"Found {len(sglang_addrs)} SGLang servers: {', '.join(sglang_addrs)}" + ) + break + + time.sleep(1) + if time.perf_counter() - start > SGLANG_SERVER_WAIT_TIMEOUT_SECONDS: + raise TimeoutError( + f"Timeout waiting for SGLang servers to be ready. " + f"Expected {n_sglang_servers} servers, found {len(sglang_addrs)}." + ) + return sglang_addrs + + +def validate_config_for_distributed_launcher(config): + n_nodes = config.cluster.n_nodes + n_gpus_per_node = config.cluster.n_gpus_per_node + allocation_mode = config.allocation_mode + allocation_mode = AllocationMode.from_str(allocation_mode) + assert ( + allocation_mode.gen_world_size + allocation_mode.train_world_size + == n_nodes * n_gpus_per_node + ), ( + f"#GPUs required for allocation mode {allocation_mode.gen_world_size + allocation_mode.train_world_size} " + f"is not equal to #GPUs in the config {n_nodes * n_gpus_per_node}." 
+ ) + if allocation_mode.type_ == AllocationType.DECOUPLED_SGLANG: + # Launcher should launch SGLang servers according to allocation mode. + assert ( + allocation_mode.gen_pp_size == 1 + ), "Pipeline generation in SGLang is not supported for now." + assert ( + allocation_mode.gen_tp_size <= config.cluster.n_gpus_per_node + ), "Currently only support SGLang TP size less <= #GPUs per node." diff --git a/arealite/utils/padding.py b/arealite/utils/padding.py deleted file mode 100644 index 6151322..0000000 --- a/arealite/utils/padding.py +++ /dev/null @@ -1,48 +0,0 @@ -from typing import List - -import torch -from tensordict import TensorDict - - -def concat_padded_tensors( - tensor_dicts: List[TensorDict], pad_value: float = 0.0 -) -> TensorDict: - """Concatenate and pad tensors from multiple padded tensor dictionaries.""" - if not tensor_dicts: - return TensorDict() - - batch_sizes = [tuple(d.batch_size) for d in tensor_dicts] - new_batch_size = [sum(x[0] for x in batch_sizes), *batch_sizes[0][1:]] - - # Find max sequence length across all dictionaries - assert all("attention_mask" in td for td in tensor_dicts) - max_length = max([x["attention_mask"].shape[1] for x in tensor_dicts]) - result = {} - # Process each key - for key in tensor_dicts[0].keys(): - tensors_to_concat = [] - for tensor_dict in tensor_dicts: - tensor = tensor_dict[key] - # Skip 1D tensors like rewards - if len(tensor.shape) == 1: - tensors_to_concat.append(tensor) - continue - current_length = tensor.shape[1] - if current_length < max_length: - # Pad tensor to max_length - pad_width = max_length - current_length - if key == "attention_mask": - # Pad attention mask with 0s - padding = torch.zeros( - (tensor.shape[0], pad_width), dtype=tensor.dtype - ) - else: - # Pad feature tensors with pad_value - padding = torch.full( - (tensor.shape[0], pad_width), pad_value, dtype=tensor.dtype - ) - tensor = torch.cat([tensor, padding], dim=1) - tensors_to_concat.append(tensor) - - result[key] = torch.cat(tensors_to_concat, dim=0) - return TensorDict(result, batch_size=new_batch_size) diff --git a/arealite/utils/ray.py b/arealite/utils/ray.py new file mode 100644 index 0000000..f049a72 --- /dev/null +++ b/arealite/utils/ray.py @@ -0,0 +1,23 @@ +import ray +from ray.util.placement_group import PlacementGroup +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy + +from arealite.utils.network import find_free_ports, gethostip + + +def get_placement_group_master_ip_and_port(placement_group: PlacementGroup): + def _master_ip_and_port(): + host_ip = gethostip() + port = find_free_ports(1, (10000, 60000))[0] + return host_ip, port + + future = ray.remote( + num_cpus=1, + num_gpus=0, + memory=10 * 1024 * 1024, # Convert MB to bytes + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=placement_group, + placement_group_bundle_index=0, + ), + )(_master_ip_and_port).remote() + return ray.get(future) diff --git a/arealite/utils/slurm.py b/arealite/utils/slurm.py new file mode 100644 index 0000000..f876c5e --- /dev/null +++ b/arealite/utils/slurm.py @@ -0,0 +1,154 @@ +import subprocess +from typing import List, Literal, Optional + +from realhf.base import logging +from realhf.scheduler.client import JobInfo, JobState + +logger = logging.getLogger("Slurm Utils") + + +SQUEUE_FIELDS = [ + "JobID", + "State", + "SubmitTime", + "StartTime", + "Name", + "NodeList", + "UserName", + "MaxCPUs", + "cpus-per-task", + "NumTasks", + "tres-alloc", +] +STATUS_MAPPING = { + "RUNNING": JobState.RUNNING, + "COMPLETING": 
JobState.RUNNING, + "PENDING": JobState.PENDING, + "CANCELLED": JobState.CANCELLED, + "FAILED": JobState.FAILED, + "COMPLETED": JobState.COMPLETED, + "OUT_OF_MEMORY": JobState.FAILED, + "DEADLINE": JobState.COMPLETED, + "TIMEOUT": JobState.COMPLETED, +} + +SBATCH_SCRIPT_TEMPLATE = """#!/bin/bash +{sbatch_options} + +# Getting the node names +nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST") +echo nodes=$nodes + +nodes_array=($nodes) +echo node_array=$nodes_array + +head_node=${{nodes_array[0]}} +echo head_node=$head_node + +# Getting the head node IP address +head_node_ip=$(srun {srun_additional_args} --nodes=1 --ntasks=1 -n1 -c1 --mem=10M --nodelist="$head_node" hostname --ip-address) +echo head_node_ip=$head_node_ip + +# Find a free port on the head node +# Wonderful linux command to find a random free port (between 10000 and 60000) by deepseek +trainer_port=$(srun {srun_additional_args} --nodes=1 --ntasks=1 -n1 -c1 --mem=10M --nodelist="$head_node" bash -c "comm -23 <(seq 10000 60000 | sort) <(ss -tan | awk '{{print $4}}' | cut -d':' -f2 | grep '[0-9]\\{{1,5\\}}' | sort -u) | shuf | head -n 1") +echo trainer_port=$trainer_port + +# srun commands +{srun_cmds} + +wait +""" + +SRUN_CMD_TEMPLATE: str = """srun {additional_args} \\ + --nodelist=${{nodes_array[{node_id}]}} --nodes={nodes} --ntasks={ntasks} \\ + --gres=gpu:{n_gpus_per_node} --cpus-per-task={cpus_per_task} --mem-per-cpu={mem_per_cpu}M \\ + {cmd} & +""" + +APPTAINER_CMD_TEMPLATE: str = """singularity exec --no-home --writable-tmpfs --nv --pid \\ + --bind {container_mounts} \\ + {container_env_strings} \\ + {container_image} \\ + {cmd}""" + + +def cancel_jobs( + slurm_names: Optional[List[str]] = None, + slurm_ids: Optional[List[int]] = None, + signal: Literal["SIGINT", "SIGKILL"] = "SIGKILL", +): + assert ( + slurm_names is not None or slurm_ids is not None + ), "Must specify slurm_names or slurm_ids." + assert not ( + slurm_names and slurm_ids + ), "Cannot specify both slurm_names and slurm_ids." + cmd = ["scancel", "-s", signal] + if slurm_names is not None: + cmd += ["-n", ",".join(slurm_names)] + elif slurm_ids is not None: + cmd += [",".join(str(s) for s in slurm_ids)] + subprocess.check_call(cmd) + logger.info( + f"Cancelled Slurm job with signal {signal}: " + f"slurm identifiers {slurm_names if slurm_ids is None else slurm_ids}. 
CMD: {cmd}" + ) + + +def query_jobs( + slurm_names: Optional[List[str]] = None, + slurm_ids: Optional[List[int]] = None, + status: str = "all", + delimiter: str = "__PSI__", +) -> List[JobInfo]: + squeue_format = f":.{delimiter},".join(SQUEUE_FIELDS) + cmd = ["squeue", "-O", squeue_format, f"-t{status}"] + if slurm_names is not None: + cmd += ["-n", ",".join(slurm_names)] + if slurm_ids is not None: + cmd += ["-j", ",".join([str(s) for s in slurm_ids])] + + output = ( + subprocess.check_output(cmd, stderr=subprocess.DEVNULL).decode("ascii").strip() + ) + rs = [] + for line in output.split("\n")[1:]: + job_id, state, submit_time, start_time, slurm_name, nodelist, *_ = line.split( + delimiter + ) + rs.append( + JobInfo( + name=slurm_name, + state=STATUS_MAPPING[state], + host=nodelist, + submit_time=submit_time, + start_time=start_time, + slurm_id=int(job_id.strip()), + ) + ) + return rs + + +def parse_slurm_nodelist(nodelist: str) -> List[str]: + return ( + subprocess.check_output( + [ + "scontrol", + "show", + "hostnames", + nodelist, + ] + ) + .decode("utf-8") + .strip() + .split("\n") + ) + + +def get_slurm_host_ip(node: str, srun_addtional_args: str): + try: + cmd = f"srun {srun_addtional_args} --immediate=1 --nodes=1 --ntasks=1 -n1 -c1 --mem=10M --nodelist={node} hostname --ip-address" + return subprocess.check_output(cmd.split(" ")).decode("utf-8").strip() + except subprocess.CalledProcessError: + logger.warning(f"Get slurm host IP for node {node} failed.") diff --git a/arealite/utils/stats_logger.py b/arealite/utils/stats_logger.py index 49c0d55..4c2612b 100644 --- a/arealite/utils/stats_logger.py +++ b/arealite/utils/stats_logger.py @@ -73,7 +73,7 @@ class StatsLogger: ) if isinstance(data, Dict): data = [data] - log_step = max(global_step, self._last_commit_step) + log_step = max(global_step, self._last_commit_step + 1) for i, item in enumerate(data): self.info(f"Stats ({i+1}/{len(data)}):") self.print_stats(item) diff --git a/arealite/workflow/rlvr.py b/arealite/workflow/rlvr.py index 3ce55df..5404972 100644 --- a/arealite/workflow/rlvr.py +++ b/arealite/workflow/rlvr.py @@ -1,11 +1,14 @@ import asyncio +import os import uuid +import colorama import torch from tensordict import TensorDict from transformers import PreTrainedTokenizerFast from arealite.api.cli_args import GenerationHyperparameters +from arealite.api.engine_api import InferenceEngine from arealite.api.io_struct import LLMRequest from arealite.api.workflow_api import RolloutWorkflow from arealite.utils.data import concat_padded_tensors @@ -18,13 +21,17 @@ class RLVRWorkflow(RolloutWorkflow): gconfig: GenerationHyperparameters, tokenizer: PreTrainedTokenizerFast, enable_thinking: bool, + dump_dir: str | None = None, ): self.reward_fn = reward_fn self.gconfig = gconfig self.tokenizer = tokenizer self.enable_thinking = enable_thinking + self.dump_dir = dump_dir + if self.dump_dir is not None and not os.path.exists(self.dump_dir): + os.makedirs(self.dump_dir, exist_ok=True) - async def arun_episode(self, engine, data): + async def arun_episode(self, engine: InferenceEngine, data): input_ids = self.tokenizer.apply_chat_template( data["messages"], tokenize=True, @@ -39,6 +46,12 @@ class RLVRWorkflow(RolloutWorkflow): ) resps = await asyncio.gather(*[engine.agenerate(req) for _ in range(n_samples)]) + version = engine.get_version() + prompt_strs = [] + completions_strs = [] + rewards = [] + seqlens = [] + results = [] for resp in resps: seq = resp.input_tokens + resp.output_tokens @@ -46,13 +59,19 @@ class 
RLVRWorkflow(RolloutWorkflow): loss_mask = [0] * resp.input_len + [1] * resp.output_len versions = [-1] * resp.input_len + resp.output_versions + prompt_str = self.tokenizer.decode(input_ids) + completions_str = self.tokenizer.decode(resp.output_tokens) + prompt_strs.append(prompt_str) + completions_strs.append(completions_str) + seqlens.append(len(seq)) reward = self.reward_fn( - prompt=self.tokenizer.decode(input_ids), - completions=self.tokenizer.decode(resp.output_tokens), + prompt=prompt_str, + completions=completions_str, prompt_ids=resp.input_tokens, completion_ids=resp.output_tokens, **data, ) + rewards.append(reward) res = dict( # unsqueeze to add an additional batch dimension input_ids=torch.tensor(seq).unsqueeze(0), @@ -65,4 +84,31 @@ class RLVRWorkflow(RolloutWorkflow): ) results.append(TensorDict(res, batch_size=[1])) + if self.dump_dir is not None: + os.makedirs(os.path.join(self.dump_dir, str(version)), exist_ok=True) + # Get the unique identifier for this prompt + qid = None + for key in ["query_id", "id", "qid"]: + qid = data.get(key, None) + if qid is not None: + break + qid = qid or uuid.uuid4().hex + + # Dump rollout to file + with open( + os.path.join(self.dump_dir, str(version), f"{qid}.txt"), "a" + ) as f: + n_samples = self.gconfig.n_samples + for i, (p, c, r, sl) in enumerate( + zip(prompt_strs, completions_strs, rewards, seqlens) + ): + info = "\n".join( + [ + f"idx: {i + 1} / {n_samples}, seqlen: {sl}, reward is {r}.", + f"prompt is \n{colorama.Fore.YELLOW + colorama.Style.DIM}{p}{colorama.Style.RESET_ALL}", + f"sequence is: \n{colorama.Fore.YELLOW + colorama.Style.DIM}{c}{colorama.Style.RESET_ALL}", + ] + ) + f.write(info + "\n") + return concat_padded_tensors(results) diff --git a/examples/arealite/boba.py b/examples/arealite/boba.py new file mode 100644 index 0000000..4211d11 --- /dev/null +++ b/examples/arealite/boba.py @@ -0,0 +1,310 @@ +import asyncio +import os +import shutil +import sys +import uuid + +import colorama +import torch +import torch.distributed as dist +from datasets import load_dataset +from datasets.distributed import split_dataset_by_node +from tensordict import TensorDict +from torchdata.stateful_dataloader import StatefulDataLoader +from transformers import PreTrainedTokenizerFast + +from arealite.api.cli_args import ( + GenerationHyperparameters, + GRPOConfig, + load_expr_config, +) +from arealite.api.io_struct import FinetuneSpec, LLMRequest, WeightUpdateMeta +from arealite.api.workflow_api import RolloutWorkflow +from arealite.engine.ppo.actor import FSDPPPOActor +from arealite.engine.sglang_remote import RemoteSGLangEngine +from arealite.utils.data import concat_padded_tensors +from arealite.utils.device import log_gpu_stats +from arealite.utils.saver import Saver +from arealite.utils.stats_logger import StatsLogger +from realhf.api.core.data_api import load_hf_tokenizer +from realhf.base import logging, seeding, stats_tracker + +logger = logging.getLogger("boba math") + + +class RLVRWorkflow(RolloutWorkflow): + def __init__( + self, + reward_fn, + gconfig: GenerationHyperparameters, + tokenizer: PreTrainedTokenizerFast, + dump_dir: str | None = None, + ): + self.reward_fn = reward_fn + self.gconfig = gconfig + self.tokenizer = tokenizer + self.dump_dir = dump_dir + if self.dump_dir is not None and not os.path.exists(self.dump_dir): + os.makedirs(self.dump_dir, exist_ok=True) + + async def arun_episode(self, engine, data): + input_ids = self.tokenizer.encode(data["prompt"]) + n_samples = self.gconfig.n_samples + req = LLMRequest( + 
diff --git a/examples/arealite/boba.py b/examples/arealite/boba.py
new file mode 100644
index 0000000..4211d11
--- /dev/null
+++ b/examples/arealite/boba.py
@@ -0,0 +1,310 @@
+import asyncio
+import os
+import shutil
+import sys
+import uuid
+
+import colorama
+import torch
+import torch.distributed as dist
+from datasets import load_dataset
+from datasets.distributed import split_dataset_by_node
+from tensordict import TensorDict
+from torchdata.stateful_dataloader import StatefulDataLoader
+from transformers import PreTrainedTokenizerFast
+
+from arealite.api.cli_args import (
+    GenerationHyperparameters,
+    GRPOConfig,
+    load_expr_config,
+)
+from arealite.api.io_struct import FinetuneSpec, LLMRequest, WeightUpdateMeta
+from arealite.api.workflow_api import RolloutWorkflow
+from arealite.engine.ppo.actor import FSDPPPOActor
+from arealite.engine.sglang_remote import RemoteSGLangEngine
+from arealite.utils.data import concat_padded_tensors
+from arealite.utils.device import log_gpu_stats
+from arealite.utils.saver import Saver
+from arealite.utils.stats_logger import StatsLogger
+from realhf.api.core.data_api import load_hf_tokenizer
+from realhf.base import logging, seeding, stats_tracker
+
+logger = logging.getLogger("boba math")
+
+
+class RLVRWorkflow(RolloutWorkflow):
+    def __init__(
+        self,
+        reward_fn,
+        gconfig: GenerationHyperparameters,
+        tokenizer: PreTrainedTokenizerFast,
+        dump_dir: str | None = None,
+    ):
+        self.reward_fn = reward_fn
+        self.gconfig = gconfig
+        self.tokenizer = tokenizer
+        self.dump_dir = dump_dir
+        if self.dump_dir is not None and not os.path.exists(self.dump_dir):
+            os.makedirs(self.dump_dir, exist_ok=True)
+
+    async def arun_episode(self, engine, data):
+        input_ids = self.tokenizer.encode(data["prompt"])
+        n_samples = self.gconfig.n_samples
+        req = LLMRequest(
+            rid=uuid.uuid4().hex,
+            input_ids=input_ids,
+            gconfig=self.gconfig.new(n_samples=1),
+        )
+        resps = await asyncio.gather(*[engine.agenerate(req) for _ in range(n_samples)])
+
+        version = engine.get_version()
+        prompt_strs = []
+        completions_strs = []
+        rewards = []
+        seqlens = []
+
+        results = []
+        for resp in resps:
+            seq = resp.input_tokens + resp.output_tokens
+            logprobs = [0.0] * resp.input_len + resp.output_logprobs
+            loss_mask = [0] * resp.input_len + [1] * resp.output_len
+            versions = [-1] * resp.input_len + resp.output_versions
+
+            prompt_str = data["prompt"]
+            completions_str = self.tokenizer.decode(resp.output_tokens)
+            prompt_strs.append(prompt_str)
+            completions_strs.append(completions_str)
+            seqlens.append(len(seq))
+            reward = self.reward_fn(
+                completions=completions_str,
+                prompt_ids=resp.input_tokens,
+                completion_ids=resp.output_tokens,
+                **data,
+            )
+            rewards.append(reward)
+            res = dict(
+                # unsqueeze to add an additional batch dimension
+                input_ids=torch.tensor(seq).unsqueeze(0),
+                loss_mask=torch.tensor(loss_mask).unsqueeze(0),
+                logprobs=torch.tensor(logprobs).unsqueeze(0),
+                versions=torch.tensor(versions).unsqueeze(0),
+                attention_mask=torch.ones(len(seq), dtype=torch.bool).unsqueeze(0),
+                # reward
+                rewards=torch.tensor([float(reward)]),
+            )
+            results.append(TensorDict(res, batch_size=[1]))
+
+        if self.dump_dir is not None:
+            os.makedirs(os.path.join(self.dump_dir, str(version)), exist_ok=True)
+            # Get the unique identifier for this prompt
+            qid = None
+            for key in ["query_id", "id", "qid"]:
+                qid = data.get(key, None)
+                if qid is not None:
+                    break
+            qid = qid or uuid.uuid4().hex
+
+            # Dump rollout to file
+            with open(
+                os.path.join(self.dump_dir, str(version), f"{qid}.txt"), "a"
+            ) as f:
+                n_samples = self.gconfig.n_samples
+                for i, (p, c, r, sl) in enumerate(
+                    zip(prompt_strs, completions_strs, rewards, seqlens)
+                ):
+                    info = "\n".join(
+                        [
+                            f"idx: {i + 1} / {n_samples}, seqlen: {sl}, reward is {r}.",
+                            f"prompt is \n{colorama.Fore.YELLOW + colorama.Style.DIM}{p}{colorama.Style.RESET_ALL}",
+                            f"sequence is: \n{colorama.Fore.YELLOW + colorama.Style.DIM}{c}{colorama.Style.RESET_ALL}",
+                        ]
+                    )
+                    f.write(info + "\n")
+
+        return concat_padded_tensors(results)
+
+
+def get_boba_math_dataset(tokenizer, rank, world_size):
+    dataset = load_dataset(
+        path="json",
+        split="train",
+        data_files="/storage/openpsi/users/xushusheng.xss/training_data/boba_106k_0319.jsonl",
+    )
+    dataset = dataset.filter(lambda x: len(tokenizer.encode(x["prompt"])) <= 1024)
+    return split_dataset_by_node(dataset, rank=rank, world_size=world_size)
+
+
+def boba_reward_fn(
+    prompt, completions, prompt_ids, completion_ids, query_id, solutions, **kwargs
+):
+    from pebble import ProcessExpired, ProcessPool
+
+    from realhf.impl.dataset.math_parser import process_results
+
+    jobs = []
+    with ProcessPool(max_workers=1) as executor:
+        for sol in solutions:
+            job = executor.schedule(
+                process_results, args=[completions, sol], timeout=15
+            )
+            jobs.append(job)
+
+    label = 0
+    for job in jobs:
+        try:
+            x = job.result()
+        except TimeoutError:
+            # print("[debug: timeout]")
+            logger.warning("Timeout occurred while verifying the math answer.")
+            x = (0, "timeout", "timeout")
+        except ProcessExpired as e:
+            logger.warning(f"Process terminated abnormally: {e}")
+            x = (0, "error", "error")
+        except Exception as e:
+            logger.warning(f"Other error occurred: {e.__class__.__name__}, {e}")
+            x = (0, "error", "error")
+        label = label or x[0]
+    return label
+
+
+def main(args):
+    config, _ = load_expr_config(args, GRPOConfig)
+    config: GRPOConfig
+
+    rank = int(os.getenv("RANK"))
+    world_size = int(os.getenv("WORLD_SIZE"))
+    tokenizer = load_hf_tokenizer(config.tokenizer_path)
+
+    seeding.set_random_seed(config.seed, key=f"trainer{rank}")
+
+    # Create dataset and dataloaders
+    train_dataloader = StatefulDataLoader(
+        get_boba_math_dataset(tokenizer, rank, world_size),
+        batch_size=config.train_dataset.batch_size // world_size,
+        shuffle=config.train_dataset.shuffle,
+        num_workers=config.train_dataset.num_workers,
+        collate_fn=lambda x: x,
+        drop_last=config.train_dataset.drop_last,
+    )
+    ft_spec = FinetuneSpec(
+        total_train_epochs=config.total_train_epochs,
+        dataset_size=len(train_dataloader) * config.train_dataset.batch_size,
+        train_batch_size=config.train_dataset.batch_size,
+    )
+
+    # Initialize inference engine
+    rollout = RemoteSGLangEngine(config.rollout)
+    rollout.initialize(None, ft_spec)
+
+    # Initialize train engine
+    actor = FSDPPPOActor(config=config.actor)
+    actor.initialize(None, ft_spec)
+    ref = None
+    if config.actor.kl_ctl > 0 and config.ref is not None:
+        ref = FSDPPPOActor(config=config.ref)
+        ref.initialize(None, ft_spec)
+
+    # Create rollout workflow
+    if tokenizer.pad_token_id not in config.gconfig.stop_token_ids:
+        config.gconfig.stop_token_ids.append(tokenizer.pad_token_id)
+    if tokenizer.eos_token_id not in config.gconfig.stop_token_ids:
+        config.gconfig.stop_token_ids.append(tokenizer.eos_token_id)
+    workflow = RLVRWorkflow(
+        reward_fn=boba_reward_fn,
+        gconfig=config.gconfig,
+        tokenizer=tokenizer,
+        dump_dir=os.path.join(
+            StatsLogger.get_log_path(config.stats_logger), "generated"
+        ),
+    )
+
+    # Run training.
+    saver = Saver(config.saver, ft_spec, for_recover=False)
+    logger = StatsLogger(config.stats_logger, ft_spec)
+
+    total_epochs = config.total_train_epochs
+    steps_per_epoch = len(train_dataloader)
+    max_steps = total_epochs * steps_per_epoch
+
+    logger.info(f"total_epochs={total_epochs} step_per_epoch={steps_per_epoch}")
+    data_generator = iter(train_dataloader)
+    for global_step in range(max_steps):
+        epoch = global_step // steps_per_epoch
+        step = global_step % steps_per_epoch
+
+        with stats_tracker.record_timing("rollout"):
+            if config.async_training:
+                batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
+            else:
+                try:
+                    data = next(data_generator)
+                except StopIteration:
+                    data_generator = iter(train_dataloader)
+                    data = next(data_generator)
+                batch = rollout.rollout_batch(data, workflow=workflow)
+
+        batch = batch.to(actor.device)
+        # Create barrier to synchronize all rollout processes.
+        dist.barrier()
+        torch.cuda.synchronize()
+
+        if config.actor.recompute_logprob or config.actor.use_decoupled_loss:
+            with stats_tracker.record_timing("recompute_logp"):
+                logp = actor.compute_logp(batch)
+                batch["prox_logp"] = logp
+                log_gpu_stats("recompute logp")
+
+        if ref is not None:
+            with stats_tracker.record_timing("ref_logp"):
+                batch["ref_logp"] = ref.compute_logp(batch)
+                log_gpu_stats("ref logp")
+
+        with stats_tracker.record_timing("compute_advantage"):
+            actor.compute_advantages(batch)
+            log_gpu_stats("compute advantages")
+
+        with (
+            stats_tracker.record_timing("train_step"),
+            stats_tracker.scope("grpo_actor"),
+        ):
+            stats = actor.ppo_update(batch)
+            actor.step_lr_scheduler()
+            log_gpu_stats("ppo update")
+
+        with stats_tracker.record_timing("update_weights"):
+            path = os.path.join(
+                Saver.get_save_checkpoint_root(config.saver),
+                "update_weights",
+                str(global_step + 1),
+            )
+            meta = WeightUpdateMeta(
+                type="disk",
+                path=path,
+                alloc_mode=None,
+                comm_backend=None,
+                model_version=global_step + 1,
+            )
+            if dist.get_rank() == 0:
+                future = rollout.update_weights(meta)
+            actor.upload_weights(meta)
+            if dist.get_rank() == 0:
+                future.result()
+            shutil.rmtree(path, ignore_errors=True)
+            dist.barrier()
+            torch.cuda.synchronize()
+            rollout.set_version(global_step + 1)
+
+        with stats_tracker.record_timing("save"):
+            saver.save(actor, epoch, step, global_step)
+
+        logger.commit(epoch, step, global_step, stats)
+
+    logger.close()
+    rollout.destroy()
+    if ref is not None:
+        ref.destroy()
+    actor.destroy()
+
+
+if __name__ == "__main__":
+    main(sys.argv[1:])
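A hedged usage sketch of `boba_reward_fn` above: it schedules one `process_results` check per reference solution in a single-worker process pool and ORs the labels, so the reward is 1 as soon as any reference matches. The literals below are invented, and running this requires the `pebble` and `realhf` dependencies:

```python
label = boba_reward_fn(
    prompt="Compute 6 * 7.",                   # forwarded via **data; unused by the checker
    completions="So the answer is \\boxed{42}.",
    prompt_ids=[],                             # unused by the checker
    completion_ids=[],                         # unused by the checker
    query_id="demo-0",
    solutions=["\\boxed{42}"],                 # one scheduled check per entry
)
print(label)  # 1 if any solution matched, otherwise 0
```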
diff --git a/examples/arealite/configs/boba.yaml b/examples/arealite/configs/boba.yaml
new file mode 100644
index 0000000..3924a14
--- /dev/null
+++ b/examples/arealite/configs/boba.yaml
@@ -0,0 +1,141 @@
+experiment_name: lite-boba-math
+trial_name: run1
+
+cluster:
+  n_nodes: 16
+  n_gpus_per_node: 8
+  cluster_name: na132
+  fileroot: /storage/openpsi/experiments
+  name_resolve:
+    type: nfs
+    nfs_record_root: /storage/openpsi/experiments/name_resolve/lite-boba-math
+    etcd3_addr: etcd-client.openpsi-etcd.svc.sigma-na130-lingbo.na130.wl-robby.local:2379
+
+seed: 1
+total_train_epochs: 10
+total_train_steps: null
+tokenizer_path: ${actor.path}
+allocation_mode: sglang.d96p1t1+d32p1t1
+async_training: true
+
+rollout:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  max_concurrent_rollouts: 400
+  queue_size: null
+  consumer_batch_size: ${train_dataset.batch_size}
+  max_head_offpolicyness: 4
+  enable_rollout_tracing: true
+
+gconfig:
+  n_samples: 16
+  min_new_tokens: 0
+  max_new_tokens: 30720
+  greedy: false
+  temperature: 1.0
+
+actor:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  path: /storage/openpsi/models/deepseek-ai__DeepSeek-R1-Distill-Qwen-1.5B/
+  init_from_scratch: false
+  disable_dropout: true
+  gradient_checkpointing: true
+  dtype: bfloat16
+  mb_spec:
+    max_tokens_per_mb: 32768
+  optimizer:
+    type: adam
+    lr: 1e-5
+    weight_decay: 0.01
+    beta1: 0.9
+    beta2: 0.999
+    eps: 1e-8
+    lr_scheduler_type: constant
+    gradient_clipping: 1.0
+    warmup_steps_proportion: 0.001
+  backend: fsdp
+
+  group_size: ${gconfig.n_samples}
+  group_adv_norm: false
+  eps_clip: 0.4
+  temperature: ${gconfig.temperature}
+  reward_scaling: 10.0
+  reward_bias: -0.5
+  kl_ctl: 0.0
+  ppo_n_minibatches: 4
+  recompute_logprob: true
+  use_decoupled_loss: true
+  behav_imp_weight_cap: 5.0
+
+ref:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  path: ${actor.path}
+  init_from_scratch: false
+  disable_dropout: true
+  dtype: ${actor.dtype}
+  mb_spec:
+    max_tokens_per_mb: 32768
+  optimizer: null
+  backend: fsdp
+
+# SGLang
+server_only: false
+sglang:
+  model_path: ${actor.path}
+  random_seed: ${seed}
+  skip_tokenizer_init: true
+  dtype: ${actor.dtype}
+  max_running_requests: null
+  context_length: 32768
+  mem_fraction_static: 0.9
+
+# datasets
+train_dataset:
+  batch_size: 512
+  shuffle: true
+  pin_memory: true
+
+# Utilities
+saver:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  freq_epochs: 1
+  freq_steps: null
+  freq_secs: null
+
+checkpointer:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  freq_epochs: 1
+  freq_steps: null
+  freq_secs: 3600
+
+evaluator:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  freq_epochs: null
+  freq_steps: null
+  freq_secs: null
+
+stats_logger:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  wandb:
+    mode: online
+
+# Launcher
+launcher:
+  inference_server_cpus_per_gpu: 15
+  inference_server_mem_per_gpu: 153600
+  trainer_cpus_per_gpu: 15
+  trainer_mem_per_gpu: 153600
+  slurm:
+    mount: /storage:/storage
+    trainer_image: /storage/openpsi/images/arealite-20250712-update-hf-xet.sif
+    inference_server_image: /storage/openpsi/images/arealite-20250712-update-hf-xet.sif
\ No newline at end of file
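The `allocation_mode` above splits the cluster's 16 x 8 = 128 GPUs into 96 for SGLang inference (`d96p1t1`) and 32 for FSDP training (`d32p1t1`). A quick sanity check of that arithmetic, using a hypothetical re-implementation of the `dXpYtZ` parsing rather than the project's actual parser:

```python
import re


def gpus_of(spec: str) -> int:
    # data-parallel x pipeline x tensor degrees
    d, p, t = map(int, re.fullmatch(r"d(\d+)p(\d+)t(\d+)", spec).groups())
    return d * p * t


# "sglang.d96p1t1+d32p1t1": inference component + training component
assert gpus_of("d96p1t1") + gpus_of("d32p1t1") == 16 * 8  # n_nodes * n_gpus_per_node
```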
diff --git a/examples/arealite/configs/gsm8k_grpo.yaml b/examples/arealite/configs/gsm8k_grpo.yaml
index ac74885..9f0107e 100644
--- a/examples/arealite/configs/gsm8k_grpo.yaml
+++ b/examples/arealite/configs/gsm8k_grpo.yaml
@@ -41,7 +41,7 @@ actor:
     max_tokens_per_mb: 10240
   optimizer:
     type: adam
-    lr: 2e-6
+    lr: 1e-5
     weight_decay: 0.01
     beta1: 0.9
     beta2: 0.999
diff --git a/examples/arealite/gsm8k_grpo.py b/examples/arealite/gsm8k_grpo.py
index 1841343..6fbccb7 100644
--- a/examples/arealite/gsm8k_grpo.py
+++ b/examples/arealite/gsm8k_grpo.py
@@ -1,5 +1,5 @@
 import os
-import re
+import shutil
 import sys
 
 import torch
@@ -18,7 +18,9 @@ from arealite.utils.saver import Saver
 from arealite.utils.stats_logger import StatsLogger
 from arealite.workflow.rlvr import RLVRWorkflow
 from realhf.api.core.data_api import load_hf_tokenizer
-from realhf.base import stats_tracker
+from realhf.base import logging, seeding, stats_tracker
+
+logger = logging.getLogger("GSM8K grpo")
 
 
 def process_gsm8k_rl_dataset(dataset: Dataset):
@@ -36,54 +38,22 @@ def get_gsm8k_dataset(split, rank, world_size):
     return process_gsm8k_rl_dataset(dataset)
 
 
-# Adapted from verl.
-def extract_solution(solution_str, method="strict") -> str | None:
-    assert method in ["strict", "flexible"]
-
-    final_answer = None
-    if method == "strict":
-        # this also tests the formatting of the model
-        solutions = re.findall("#### (\\-?[0-9\\.\\,]+)", solution_str)
-        if len(solutions) == 0:
-            final_answer = None
-        else:
-            # take the last solution
-            final_answer = solutions[-1].replace(",", "").replace("$", "")
-    elif method == "flexible":
-        answer = re.findall("(\\-?[0-9\\.\\,]+)", solution_str)
-        final_answer = None
-        if len(answer) == 0:
-            # no reward is there is no answer
-            pass
-        else:
-            invalid_str = ["", "."]
-            # find the last number that is not '.'
-            for final_answer in reversed(answer):
-                if final_answer not in invalid_str:
-                    break
-    return final_answer
-
-
 def gsm8k_reward_fn(prompt, completions, prompt_ids, completion_ids, answer, **kwargs):
-    from realhf.impl.dataset.math_parser import extract_answer
+    from realhf.impl.dataset.math_parser import process_results
 
-    sol = extract_answer(completions, data_name="math")
-    ans = extract_solution(solution_str=answer, method="strict")
-    if sol is None:
-        return 0
-    if ans is None:
-        return 0
-    return int(sol.strip() == ans.strip())
+    return int(process_results(completions, answer)[0])
 
 
-def main_grpo():
-    config, _ = load_expr_config(sys.argv[1:], GRPOConfig)
+def main(args):
+    config, _ = load_expr_config(args, GRPOConfig)
     config: GRPOConfig
 
     rank = int(os.getenv("RANK"))
     world_size = int(os.getenv("WORLD_SIZE"))
     tokenizer = load_hf_tokenizer(config.tokenizer_path)
 
+    seeding.set_random_seed(config.seed, key=f"trainer{rank}")
+
     # Create dataset and dataloaders
     train_dataloader = StatefulDataLoader(
         get_gsm8k_dataset("train", rank, world_size),
@@ -133,6 +103,9 @@ def main_grpo():
         gconfig=config.gconfig,
         tokenizer=tokenizer,
         enable_thinking=False,
+        dump_dir=os.path.join(
+            StatsLogger.get_log_path(config.stats_logger), "generated"
+        ),
     )
 
     # Run training.
@@ -190,13 +163,14 @@ def main_grpo():
             log_gpu_stats("ppo update")
 
         with stats_tracker.record_timing("update_weights"):
+            path = os.path.join(
+                Saver.get_save_checkpoint_root(config.saver),
+                "update_weights",
+                str(global_step + 1),
+            )
             meta = WeightUpdateMeta(
                 type="disk",
-                path=os.path.join(
-                    Saver.get_save_checkpoint_root(config.saver),
-                    "update_weights",
-                    str(global_step),
-                ),
+                path=path,
                 alloc_mode=None,
                 comm_backend=None,
                 model_version=global_step + 1,
@@ -206,6 +180,7 @@ def main_grpo():
             actor.upload_weights(meta)
             if dist.get_rank() == 0:
                 future.result()
+            shutil.rmtree(path, ignore_errors=True)
             dist.barrier()
             torch.cuda.synchronize()
             rollout.set_version(global_step + 1)
@@ -253,4 +228,4 @@ def main_grpo():
 
 
 if __name__ == "__main__":
-    main_grpo()
+    main(sys.argv[1:])
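The simplified `gsm8k_reward_fn` now delegates answer matching entirely to `process_results`, whose first return element is treated as the 0/1 match label (the same assumption `boba_reward_fn` makes above). A hedged call sketch with invented strings; it requires `realhf` on the import path:

```python
reward = gsm8k_reward_fn(
    prompt="Natalia sold 48 clips in April and half as many in May. How many in total?",
    completions="48 + 24 = 72. #### 72",
    prompt_ids=[],       # unused by the reward function
    completion_ids=[],   # unused by the reward function
    answer="72",
)
print(reward)  # expected to be 1 on a match, 0 otherwise
```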
diff --git a/examples/arealite/gsm8k_sft.py b/examples/arealite/gsm8k_sft.py
index c1d8735..eaacd75 100644
--- a/examples/arealite/gsm8k_sft.py
+++ b/examples/arealite/gsm8k_sft.py
@@ -35,8 +35,8 @@ def get_gsm8k_dataset(split, tokenizer, rank, world_size):
     return process_gsm8k_sft_dataset(dataset, tokenizer)
 
 
-def main_sft():
-    config, _ = load_expr_config(sys.argv[1:], SFTConfig)
+def main(args):
+    config, _ = load_expr_config(args, SFTConfig)
     config: SFTConfig
 
     rank = int(os.getenv("RANK"))
@@ -121,4 +121,4 @@ def main_sft():
 
 
 if __name__ == "__main__":
-    main_sft()
+    main(sys.argv[1:])