[lite] [refactor] Add GSM8k GRPO example. (#179)

* PullRequest: 353 [Lite] Add gradient checkpointing to FSDPEngine

Merge branch mzy/add-gradient-ckpt of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/353

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* add gradient checkpointing

* PullRequest: 354 [lite] GRPO pre-commit: minor changes in FSDP engine

Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/354

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .
* .
* .

* PullRequest: 355 [Lite] GRPO pre-commit 2: Refactor RemoteSGLangEngine thread and SGLang configuration

Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/355?tab=commit

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .
* .
* .
* .
* .
* fix
* .

* PullRequest: 357 [lite] GRPO pre-commit 3: Fix typos and experiment utilities

Merge branch fw/lite-fix2 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/357?tab=comment

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .
* .
* .
* .
* fix destroy process group

* fix ci

* PullRequest: 358 [lite] Support GRPO training locally with the GSM8k dataset

Merge branch fw/lite-fix3 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/358

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .
* .
* .
* fix loss mask
* fix
* .

---------

Co-authored-by: 晓雷 <meizhiyu.mzy@antgroup.com>
Wei Fu, 2025-07-16 13:10:26 +08:00, committed by GitHub
parent 4490b117e4
commit e13db01f67
29 changed files with 2056 additions and 417 deletions


@@ -92,7 +92,7 @@ def main_grpo():
future.result()
# synchronous rollout
rollout_batch = rollout.rollout(batch, workflow=MyRolloutWorkflow(rollout_config.workflow))
rollout_batch = rollout.rollout_batch(batch, workflow=MyRolloutWorkflow(rollout_config.workflow))
# or asynchronous rollout with filtering and off-policyness control
# rollout_batch = rollout.prepare_batch(batch,
# workflow=MyRolloutWorkflow(rollout_config.workflow),
@@ -697,7 +697,7 @@ reward = TrainController(Critic())
rollout_controller = RolloutController(...)
for _ in range(epochs):
for _ in range(steps_per_epoch):
data = rollout_controller.rollout(prompt)
data = rollout_controller.rollout_batch(prompt)
data['reward'] = reward.compute_values(data)
...
```


@@ -2,7 +2,7 @@ import argparse
import os
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional
import uvloop
@@ -12,7 +12,6 @@ from hydra import initialize as hydra_init
from omegaconf import MISSING, OmegaConf
from arealite.utils.fs import get_user_tmp
from realhf.api.cli_args import OptimizerConfig
@dataclass
@@ -84,6 +83,63 @@ class GenerationHyperparameters:
# Train Engine Configs
@dataclass
class OptimizerConfig:
"""Configuration for model optimization during training.
Note:
Set type to "empty" for models that won't be trained.
"""
type: str = field(
default="adam",
metadata={"help": "Optimizer type", "choices": ["adam", "empty"]},
)
lr: float = field(default=2e-5, metadata={"help": "Learning rate"})
weight_decay: float = field(default=0.05, metadata={"help": "Weight decay"})
beta1: float = field(default=0.9, metadata={"help": "Adam beta1 parameter"})
beta2: float = field(default=0.95, metadata={"help": "Adam beta2 parameter"})
eps: float = field(default=1e-5, metadata={"help": "Adam epsilon parameter"})
min_lr_ratio: float = field(
default=0.0,
metadata={
"help": "Minimum learning rate ratio after annealing",
},
)
lr_scheduler_type: str = field(
default="constant",
metadata={
"help": "Learning rate scheduler type",
"choices": ["linear", "cosine", "constant"],
},
)
warmup_steps_proportion: float = field(
default=0.001,
metadata={
"help": "Proportion of training steps for warmup",
},
)
offload: bool = field(
default=False, metadata={"help": "Enable optimizer state offloading"}
)
initial_loss_scale: float = field(
default=2**32, metadata={"help": "Initial loss scaling factor"}
)
min_loss_scale: float = field(
default=1.0, metadata={"help": "Minimum loss scaling factor"}
)
loss_scale_window: float = field(
default=5, metadata={"help": "Window size for loss scaling adjustment"}
)
hysteresis: int = field(
default=2, metadata={"help": "Hysteresis (scaling factor) for loss scaling"}
)
gradient_clipping: float = field(
default=1.0, metadata={"help": "Gradient clipping threshold"}
)
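For orientation, a minimal sketch (assuming plain torch and a toy model) of how these fields map onto the AdamW optimizer that FSDPEngine asserts on later in this diff; the step count is illustrative:
```python
import torch

from arealite.api.cli_args import OptimizerConfig

cfg = OptimizerConfig()  # defaults as defined above
model = torch.nn.Linear(8, 8)  # stand-in for the real policy model
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=cfg.lr,
    weight_decay=cfg.weight_decay,
    betas=(cfg.beta1, cfg.beta2),
    eps=cfg.eps,
)
total_steps = 1_000  # normally derived from the finetune spec
warmup_steps = int(cfg.warmup_steps_proportion * total_steps)
```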
@dataclass
class FSDPWrapPolicy:
transformer_layer_cls_to_wrap: Optional[List[str]] = field(
@@ -135,10 +191,11 @@ class TrainEngineConfig:
mb_spec: MicroBatchSpec = field(default_factory=MicroBatchSpec)
# Training Backend Configuration
disable_dropout: bool = field(default=False)
gradient_checkpointing: bool = field(
default=True, metadata={"help": "Enable gradient checkpointing"}
)
bf16: bool = field(default=False, metadata={"help": "Use bf16 precision"})
dtype: str = field(default="float16", metadata={"help": "Parameter dtype."})
optimizer: Optional[OptimizerConfig] = field(
default=None, metadata={"help": "Optimizer configuration"}
)
@@ -147,12 +204,94 @@ class TrainEngineConfig:
hf: HFEngineConfig = field(default_factory=HFEngineConfig)
@dataclass
class PPOActorConfig(TrainEngineConfig):
# Core PPO/GRPO Parameters
group_size: int = field(
default=1, metadata={"help": "Number of sequences in each group"}
)
group_adv_norm: bool = field(
default=False,
metadata={
"help": "Normalize advantages within each prompt group rather than globally"
},
)
ppo_n_minibatches: int = field(
default=4, metadata={"help": "Number of minibatches for each PPO update"}
)
eps_clip: float = field(
default=0.2, metadata={"help": "Clipping factor for policy ratio"}
)
c_clip: Optional[float] = field(
default=None,
metadata={
"help": "Dual clipping factor for policy ratio, must > 1.0. None disables dual clipping."
},
)
temperature: float = field(
default=1.0, metadata={"help": "Temperature during generation."}
)
# Reward
group_reward_norm: bool = field(
default=False,
metadata={
"help": "Normalize final reward of each sequence (GRPO-style) to reduce length bias"
},
)
reward_scaling: float = field(
default=1.0, metadata={"help": "Reward scaling factor"}
)
reward_bias: float = field(default=0.0, metadata={"help": "Reward bias"})
reward_clip: float = field(
default=20.0, metadata={"help": "Maximum absolute value for reward clipping"}
)
mask_no_eos_with_zero: bool = field(
default=False,
metadata={
"help": "Mask truncated generations (no EOS token) and exclude from training"
},
)
# Advantage Estimation
discount: float = field(
default=1.0, metadata={"help": "Discount factor for future rewards"}
)
gae_lambda: float = field(
default=1.0, metadata={"help": "Lambda parameter for GAE"}
)
adv_norm: bool = field(
default=True, metadata={"help": "Enable advantage normalization"}
)
# KL Control
kl_ctl: float = field(default=0.1, metadata={"help": "KL divergence coefficient"})
# Asynchronous RL
recompute_logprob: bool = field(
default=False,
metadata={"help": "Recompute logp and replace the logp returned by inference."},
)
use_decoupled_loss: bool = field(
default=False,
metadata={"help": "Use the decoupled loss. recompute_logprob must be True."},
)
behav_imp_weight_cap: Optional[float] = field(
default=None,
metadata={
"help": "We filter out the tokens where behav_imp_weight exceeds behav_imp_weight_cap when computing the loss, must be > 1.0, use_decoupled_loss must be true"
},
)
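A hedged sketch of composing this config with OmegaConf (values are invented; per the help strings above, use_decoupled_loss is only valid when recompute_logprob is enabled):
```python
from omegaconf import OmegaConf

from arealite.api.cli_args import PPOActorConfig

cfg = OmegaConf.structured(PPOActorConfig)
cfg.group_size = 8
cfg.group_reward_norm = True   # GRPO-style per-group reward normalization
cfg.recompute_logprob = True
cfg.use_decoupled_loss = True  # requires recompute_logprob=True
print(OmegaConf.to_yaml(cfg))
```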
@dataclass
class SGLangConfig:
"""Configuration for SGLang runtime. Refer to:
https://github.com/sgl-project/sglang for detailed documentation.
"""
model_path: str = ""
random_seed: int = 1
skip_tokenizer_init: bool = False
disable_cuda_graph: bool = False
disable_radix_cache: bool = False
disable_cuda_graph_padding: bool = False
@@ -188,10 +327,8 @@ class SGLangConfig:
schedule_policy: str = "lpm"
schedule_conservativeness: float = 1.0
cpu_offload_gb: int = 0
dtype: str = "float16"
kv_cache_dtype: str = "auto"
# logging
log_level: str = "warning"
log_level_http: Optional[str] = "warning"
@@ -207,21 +344,19 @@ class SGLangConfig:
@staticmethod
def build_cmd(
sglang_config: "SGLangConfig",
model_path,
tp_size,
base_gpu_id,
host,
port,
dist_init_addr: Optional[str] = None,
served_model_name: Optional[str] = None,
skip_tokenizer_init: bool = True,
):
args = SGLangConfig.build_args(
sglang_config=sglang_config,
model_path=model_path,
tp_size=tp_size,
base_gpu_id=base_gpu_id,
host=host,
port=port,
dist_init_addr=dist_init_addr,
served_model_name=served_model_name,
skip_tokenizer_init=skip_tokenizer_init,
)
# convert to flags
@@ -240,48 +375,54 @@ class SGLangConfig:
@staticmethod
def build_args(
sglang_config: "SGLangConfig",
model_path,
tp_size,
base_gpu_id,
host,
port,
dist_init_addr: Optional[str] = None,
served_model_name: Optional[str] = None,
skip_tokenizer_init: bool = True,
):
from realhf.base import network, pkg_version, seeding
from realhf.base import pkg_version
from realhf.experiments.common.utils import asdict as conf_as_dict
args: Dict = conf_as_dict(sglang_config)
args["random_seed"] = seeding.get_seed()
if served_model_name is None:
served_model_name = model_path
host_ip = network.gethostip()
host = "localhost" if not sglang_config.enable_metrics else host_ip
args = dict(
host=host,
model_path=model_path,
port=port,
# Model and tokenizer
tokenizer_path=model_path,
tokenizer_path=sglang_config.model_path,
tokenizer_mode="auto",
load_format="auto",
trust_remote_code=True,
device="cuda",
served_model_name=served_model_name,
is_embedding=False,
skip_tokenizer_init=skip_tokenizer_init,
# Other runtime options
tp_size=tp_size,
# Because we have set CUDA_VISIBLE_DEVICES to a single GPU in each process
base_gpu_id=base_gpu_id,
nnodes=1,
node_rank=0,
# initialization addresses and ports
dist_init_addr=dist_init_addr,
**args,
)
if pkg_version.is_version_less("sglang", "0.4.4"):
sglang_version = pkg_version.get_version("sglang")
if sglang_version:
version_less_than_0_4_4 = (
pkg_version.compare_versions(sglang_version, "0.4.4") < 0
)
version_less_than_0_4_3 = (
pkg_version.compare_versions(sglang_version, "0.4.3") < 0
)
elif pkg_version.is_available("sglang"):
version_less_than_0_4_4 = pkg_version.is_version_less("sglang", "0.4.4")
version_less_than_0_4_3 = pkg_version.is_version_less("sglang", "0.4.3")
else:
raise ValueError(
"A installed SGLang package or a specific SGLang version should be provided to build SGLang server cmd."
)
if version_less_than_0_4_4:
args.pop("log_requests_level")
if pkg_version.is_version_less("sglang", "0.4.3"):
if version_less_than_0_4_3:
args.pop("enable_nccl_nvls")
args.pop("triton_attention_num_kv_splits")
args.pop("cuda_graph_bs")
@@ -294,8 +435,8 @@ class SGLangConfig:
@dataclass
class InferenceEngineConfig:
experiment_name: str
trial_name: str
experiment_name: str = MISSING
trial_name: str = MISSING
max_concurrent_rollouts: None | int = field(
default=None,
metadata={
@@ -318,28 +459,20 @@ class InferenceEngineConfig:
"the request will not be accepted.",
},
)
# Used by remote inference engines.
server_addrs: List[str] = field(
default_factory=list,
metadata={"help": "List of server addresses for inference."},
)
enable_rollout_tracing: bool = field(default=False)
schedule_policy: str = field(
default="round_robin",
metadata={"help": "Request scheduling policy", "choices": ["round_robin"]},
)
setup_timeout: float = field(default=90.0)
request_timeout: float = field(
default=30.0, metadata={"help": "Timeout for HTTP requests."}
default=3600, metadata={"help": "Timeout for HTTP requests."}
)
request_retries: int = field(
default=3, metadata={"help": "Number of retries for failed requests."}
)
@dataclass
class SGLangEngineConfig:
pass
@dataclass
class _Timer:
experiment_name: str = MISSING
@@ -569,39 +702,58 @@ class BaseExperimentConfig:
evaluator: EvaluatorConfig = field(default_factory=EvaluatorConfig)
stats_logger: StatsLoggerConfig = field(default_factory=StatsLoggerConfig)
server_only: bool = False
sglang: SGLangConfig = field(default_factory=SGLangConfig)
@dataclass
class SFTConfig(BaseExperimentConfig):
model: TrainEngineConfig = field(default_factory=TrainEngineConfig)
def load_expr_config(argv: List[str], config_cls) -> Tuple[BaseExperimentConfig, str]:
@dataclass
class GRPOConfig(BaseExperimentConfig):
async_training: bool = field(default=True)
gconfig: GenerationHyperparameters = field(
default_factory=GenerationHyperparameters
)
rollout: InferenceEngineConfig = field(default_factory=InferenceEngineConfig)
actor: PPOActorConfig = field(default_factory=PPOActorConfig)
ref: PPOActorConfig = field(default_factory=PPOActorConfig)
def parse_cli_args(argv: List[str]):
parser = argparse.ArgumentParser()
parser.add_argument(
"--config", help="The path of the main configuration file", required=True
)
args, overrides = parser.parse_known_args(argv)
# Initialize hydra config
config_file = Path(args.config).absolute()
assert config_file.exists()
# hydra only recognize relative paths
relpath = Path(
os.path.relpath(str(config_file), (Path(__file__).parent).absolute())
)
relpath = Path(os.path.relpath(str(config_file), Path(__file__).parent.absolute()))
hydra_init(config_path=str(relpath.parent), job_name="app", version_base=None)
cfg = hydra_compose(
config_name=str(relpath.name).rstrip(".yaml"),
config_name=str(relpath.name).split(".yaml")[0],
overrides=overrides,
)
return cfg, config_file
def to_structured_cfg(cfg, config_cls):
# Merge with the default configuration.
# The yaml and commandline can omit some default values defined in python dataclasses.
default_cfg = OmegaConf.structured(config_cls)
cfg = OmegaConf.merge(default_cfg, cfg)
return cfg
def load_expr_config(argv: List[str], config_cls):
cfg, config_file = parse_cli_args(argv)
cfg = to_structured_cfg(cfg, config_cls=config_cls)
cfg = OmegaConf.to_object(cfg)
assert isinstance(cfg, BaseExperimentConfig)
# Setup environment
from realhf.base import constants, name_resolve
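Two notes on this hunk. The config_name change is a real fix: str.rstrip(".yaml") strips any trailing run of the characters {., y, a, m, l}, so a file named llama.yaml would be reduced to an empty name, while split(".yaml")[0] removes only the literal suffix. And assuming load_expr_config still returns the config plus a name, as its earlier signature indicated, the resulting entry-point pattern looks roughly like this (script and YAML names are hypothetical):
```python
import sys

from arealite.api.cli_args import GRPOConfig, load_expr_config

def main(argv):
    # e.g. python train_grpo.py --config gsm8k_grpo.yaml actor.optimizer.lr=1e-5
    config, _ = load_expr_config(argv, GRPOConfig)
    ...

if __name__ == "__main__":
    main(sys.argv[1:])
```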


@@ -4,7 +4,9 @@ from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
import torch
import torch.distributed as dist
from tensordict import TensorDict
from torchdata.stateful_dataloader import StatefulDataLoader
from arealite.api.io_struct import (
FinetuneSpec,
@@ -40,6 +42,11 @@ class TrainEngine(abc.ABC):
"""Initialize environments for distributed training and load models."""
raise NotImplementedError()
@property
def parallelism_group(self) -> dist.ProcessGroup:
"""The global communication group of this engine."""
raise NotImplementedError()
def get_scheduling_config(self) -> Scheduling:
"""Get the scheduling configuration for the engine, e.g., image, cpu/gpu/memory size."""
raise NotImplementedError()
@@ -77,9 +84,9 @@ class TrainEngine(abc.ABC):
def train_batch(
self,
input_: Dict,
loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
loss_weight_fn: Callable[[Dict], float],
input_: TensorDict,
loss_fn: Callable[[torch.Tensor, TensorDict], torch.Tensor],
loss_weight_fn: Callable[[TensorDict], float],
) -> Dict[str, float]:
"""Update the model with a batch of data and a loss function."""
raise NotImplementedError()
@@ -87,9 +94,9 @@ class TrainEngine(abc.ABC):
@torch.no_grad()
def eval_batch(
self,
input_: Dict,
loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
loss_weight_fn: Callable[[Dict], float],
input_: TensorDict,
loss_fn: Callable[[torch.Tensor, TensorDict], torch.Tensor],
loss_weight_fn: Callable[[TensorDict], float],
) -> torch.Tensor | None:
"""Evaluate the model using the forward pass and loss function."""
raise NotImplementedError()
@@ -97,9 +104,9 @@ class TrainEngine(abc.ABC):
@torch.no_grad()
def forward(
self,
input_: Dict,
input_: TensorDict,
output_seqlens: List[List[int]] | None = None,
post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
post_hook: Callable[[torch.Tensor, TensorDict], Any] | None = None,
aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
) -> Any | None:
"""Run the forward pass or inference on the model. Note that it is gradient-free."""
@@ -127,12 +134,33 @@ class InferenceEngine(abc.ABC):
"""Asynchronously submit a request to the inference engine. Exits immediately."""
raise NotImplementedError()
def wait(self, count: int, timeout: float) -> TensorDict:
def wait(
self,
count: int,
timeout: float | None = None,
should_accept: Callable | None = None,
) -> TensorDict:
"""Wait for a specified number of requests to complete, with a timeout."""
raise NotImplementedError()
def rollout(
def rollout_batch(
self, data: List[Dict[str, Any]], workflow: "RolloutWorkflow"
) -> TensorDict:
"""Submit a batch of requests to the inference engine and wait for the results."""
raise NotImplementedError()
def prepare_batch(
self,
dataloader: StatefulDataLoader,
workflow: "RolloutWorkflow",
):
"""Asynchronously submit and wait until a full batch is ready."""
raise NotImplementedError()
def pause(self):
"""Pause request submission for async rollout. Used during evaluation to prevent data over generation."""
raise NotImplementedError()
def resume(self):
"""Resume request submission for async rollout."""
raise NotImplementedError()
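A minimal usage sketch of this pause/resume contract (evaluate_fn is a placeholder):
```python
def run_evaluation(engine, evaluate_fn):
    """Evaluate while rollout submission is paused; always resume."""
    engine.pause()
    try:
        return evaluate_fn(engine)
    finally:
        engine.resume()
```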


@@ -16,7 +16,6 @@ from arealite.api.cli_args import GenerationHyperparameters
@dataclass
class LLMRequest:
rid: str = field(default_factory=lambda: str(uuid.uuid4()))
text: Optional[str] = None
input_ids: List[int] = field(default_factory=list)
gconfig: GenerationHyperparameters = field(
default_factory=GenerationHyperparameters
@@ -28,7 +27,6 @@ class LLMRequest:
@dataclass
class LLMResponse:
# outputs
completions: str
input_tokens: List[int] = field(default_factory=list)
output_tokens: List[int] = field(default_factory=list)
output_logprobs: List[float] = field(default_factory=list)


@@ -1,6 +1,7 @@
import gc
import os
import time
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional
import torch
@@ -32,7 +33,7 @@ from arealite.utils.data import (
pad_and_stack_tensors_along_first_dim,
pad_mb_list,
reorder_list,
split_packed_tensor_dict_into_mb_list,
split_padded_tensor_dict_into_mb_list,
unpack_sequence,
unsqueeze_mb_list,
)
@@ -45,9 +46,10 @@ from arealite.utils.fsdp import (
fsdp2_load_full_state_dict,
get_cosine_schedule_with_warmup,
)
from arealite.utils.model import disable_dropout_in_model
from arealite.utils.save_load import get_state_dict_from_repo_id_or_path
from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import logging, name_resolve, names, pkg_version
from realhf.base import constants, logging, name_resolve, names, pkg_version
logger = logging.getLogger("FSDPEngine")
@@ -68,6 +70,8 @@ class FSDPEngine(TrainEngine):
self.cpu_offload = None
# initialization
self.initialized = False
self.own_global_group = False
self._parallelism_group = None
self.weight_update_group_initialized = False
# TODO: Handle the case when WORLD_SIZE is not set in launcher
@@ -78,6 +82,11 @@ class FSDPEngine(TrainEngine):
self.model.train(mode=mode)
return self
@property
def parallelism_group(self) -> dist.ProcessGroup:
assert self.initialized
return self._parallelism_group
def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None):
# Initialize distributed environments and load the model.
assert addr is None, "FSDPEngine does not support remote initialization."
@@ -89,29 +98,54 @@ class FSDPEngine(TrainEngine):
"""Initialize distributed communication and model."""
if not dist.is_initialized():
# TODO: Handle the condition when WORLD_SIZE and RANK are not set in launcher
dist.init_process_group(backend="nccl")
dist.init_process_group(
backend="nccl",
timeout=constants.NCCL_DEFAULT_TIMEOUT,
device_id=torch.device(int(os.environ["LOCAL_RANK"])),
)
self.own_global_group = True
self._parallelism_group = dist.new_group()
# TODO: Handle the condition when LOCAL_RANK is not set in launcher
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
self.device = torch.device(int(os.environ["LOCAL_RANK"]))
dtype = torch.bfloat16 if self.config.bf16 else torch.float16
dtype = getattr(torch, self.config.dtype)
self.model_config = AutoConfig.from_pretrained(
pretrained_model_name_or_path=self.config.path,
trust_remote_code=True,
)
self.tokenizer = load_hf_tokenizer(self.config.path)
tik = time.perf_counter()
with torch.device("cuda"):
# initialize scratch model from config
model = AutoModelForCausalLM.from_config(
self.model_config,
torch_dtype=dtype,
attn_implementation=self.config.attn_impl,
if self.config.init_from_scratch:
# initialize scratch model from config
# NOTE: VLMs cannot directly load a state dict into this
# randomly initialized model, so when not initializing from scratch
# we call from_pretrained instead of loading weights into it.
model = AutoModelForCausalLM.from_config(
self.model_config,
torch_dtype=dtype,
attn_implementation=self.config.attn_impl,
)
else:
model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=self.config.path,
trust_remote_code=True,
torch_dtype=dtype,
attn_implementation=self.config.attn_impl,
)
if self.config.disable_dropout:
disable_dropout_in_model(model)
if self.config.gradient_checkpointing:
model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False}
)
logger.info(f"Model creation and loading time: {time.perf_counter() - tik}")
# Simple auto wrap policy
self.mixed_precision_policy = MixedPrecisionPolicy(
param_dtype=torch.bfloat16,
param_dtype=dtype,
reduce_dtype=torch.float32,
cast_forward_inputs=True,
)
@@ -129,23 +163,14 @@ class FSDPEngine(TrainEngine):
}
# Wrap with FSDP2
tik = time.perf_counter()
apply_fsdp2(model, fsdp_kwargs, self.config.fsdp.wrap_policy)
logger.info(f"Applying FSDP2 time: {time.perf_counter() - tik}")
self.model = model
if not self.config.init_from_scratch:
# Load model from a initial checkpoint path,
# which should only be a huggingface checkpoint.
load_meta = SaveLoadMeta(
path=self.config.path,
weight_format="hf",
with_optim=False,
tokenizer=None,
base_model_path=self.config.path,
)
self.load(load_meta)
# Set up optimizer
if self.optimizer_config is not None:
tik = time.perf_counter()
assert (
self.optimizer_config.type == "adam"
), "Only AdamW optimizer is supported in this engine."
@@ -189,6 +214,7 @@ class FSDPEngine(TrainEngine):
raise ValueError(
f"Unknown lr scheduler type {self.optimizer_config.lr_scheduler_type}"
)
logger.info(f"Create optimizer time: {time.perf_counter() - tik}")
self.initialized = True
@@ -199,6 +225,9 @@ class FSDPEngine(TrainEngine):
gc.collect()
torch.cuda.empty_cache()
gc.collect()
dist.destroy_process_group(self.parallelism_group)
if self.own_global_group:
dist.destroy_process_group()
self.initialized = False
def save(self, meta: SaveLoadMeta):
@@ -300,7 +329,9 @@ class FSDPEngine(TrainEngine):
self.config.trial_name,
meta.model_version,
)
name_resolve.add(update_name, str(time.time_ns()), keepalive_ttl=120)
name_resolve.add(
update_name, str(datetime.now().timestamp()), keepalive_ttl=120
)
else:
raise ValueError(f"Unknown weight update type {meta.type}")
@@ -323,15 +354,19 @@ class FSDPEngine(TrainEngine):
if isinstance(input_, dict):
input_ = TensorDict(input_, batch_size=[input_["input_ids"].shape[0]])
input_ = amend_position_ids(input_)
packed_input = pack_tensor_dict(input_)
mb_list = split_packed_tensor_dict_into_mb_list(
packed_input,
self.config.mb_spec,
mb_list = split_padded_tensor_dict_into_mb_list(input_, self.config.mb_spec)
logger.info(
f"Microbatch #tokens (rank {dist.get_rank()}): {mb_list.group_lens}"
)
mb_list.mbs = [pack_tensor_dict(mb) for mb in mb_list.mbs]
mb_list = pad_mb_list(mb_list, pad_value=0.0)
# NOTE: We unsqueeze here because huggingface transformer models require
# packed input to be of shape [1, total_seqlen].
mb_list = unsqueeze_mb_list(mb_list)
# FIXME: the resulting max_seqlen is a tensor rather than an integer
for mb in mb_list.mbs:
mb["max_seqlen"] = int(mb["max_seqlen"])
mb["use_cache"] = False
return mb_list
def train_batch(
@@ -356,9 +391,10 @@ class FSDPEngine(TrainEngine):
dist.all_reduce(total_loss_weight)
# Process microbatches with gradient accumulation
for pad_length, padded_mb_input, mb_input in zip(
mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
for i, (pad_length, padded_mb_input, mb_input) in enumerate(
zip(mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs)
):
self.model.set_is_last_backward(i == len(mb_list.mbs) - 1)
outputs = self.model(**padded_mb_input)
logits = outputs.logits.squeeze(0)


@@ -15,7 +15,7 @@ from transformers import (
get_linear_schedule_with_warmup,
)
from arealite.api.cli_args import MicroBatchSpec, TrainEngineConfig
from arealite.api.cli_args import TrainEngineConfig
from arealite.api.engine_api import (
FinetuneSpec,
SaveLoadMeta,
@@ -81,7 +81,7 @@ class HFEngine(TrainEngine):
torch.cuda.set_device(local_rank)
self.device = torch.device(f"cuda:{local_rank}")
dtype = torch.bfloat16 if self.config.bf16 else torch.float16
dtype = getattr(torch, self.config.dtype)
self.model_config = AutoConfig.from_pretrained(
pretrained_model_name_or_path=self.config.path,
trust_remote_code=True,


@@ -0,0 +1,334 @@
import functools
from typing import Dict, List, Optional
import torch
from tensordict import TensorDict
from arealite.api.cli_args import MicroBatchSpec, PPOActorConfig
from arealite.api.engine_api import TrainEngine
from arealite.engine.fsdp_engine import FSDPEngine
from arealite.utils.data import split_padded_tensor_dict_into_mb_list
from arealite.utils.functional import (
gather_logprobs,
gather_logprobs_entropy,
masked_normalization,
ppo_actor_loss_fn,
)
from realhf.base import stats_tracker
class PPOActor:
def __init__(self, config: PPOActorConfig, engine: TrainEngine):
self.config = config
self.engine = engine
self.reward_bias = config.reward_bias
self.reward_scaling = config.reward_scaling
self.reward_clip = config.reward_clip
self.group_reward_norm = config.group_reward_norm
self.group_adv_norm = config.group_adv_norm
self.group_size = config.group_size
self.kl_ctl = config.kl_ctl
self.adv_norm = config.adv_norm
self.discount = config.discount
self.gae_lambda = config.gae_lambda
self.mask_no_eos_with_zero = config.mask_no_eos_with_zero
self.temperature = config.temperature
@torch.no_grad()
def compute_logp(
self,
data: TensorDict,
temperature: Optional[float] = None,
) -> torch.Tensor | None:
def calc_logprobs(logits, input_data):
labels = torch.roll(input_data["input_ids"], shifts=-1, dims=-1)
logprobs = gather_logprobs(logits, labels, temperature or 1.0)
return logprobs
self.engine.eval()
return self.engine.forward(
input_=data,
post_hook=calc_logprobs,
aggregate_fn=lambda xs: torch.cat(xs, dim=-1),
)
def compute_advantages(self, data: TensorDict) -> None:
bs = data["input_ids"].shape[0]
max_seqlen = data["input_ids"].shape[1]
batch_indices = torch.arange(
bs, device=data["input_ids"].device, dtype=torch.long
)
# Compute rewards using the reward function in synchronous RLVR pipeline.
reward_score = data["rewards"]
reward_score = (reward_score + self.reward_bias) * self.reward_scaling
reward_score = torch.clip(
reward_score, max=self.reward_clip, min=-self.reward_clip
)
if self.group_reward_norm:
for i in range(bs // self.group_size):
s = slice(i * self.group_size, (i + 1) * self.group_size)
r = reward_score[s]
reward_score[s] = (r - r.mean()) / (r.std() + 1e-9)
loss_mask = data["loss_mask"].float()
loss_mask = torch.roll(loss_mask, shifts=-1, dims=-1)
# Apply the mask to log probabilities.
if not self.config.use_decoupled_loss and self.config.recompute_logprob:
# Overwrite logprobs produced by the inference engine
old_logp = data["logprobs"] = data["prox_logp"]
else:
old_logp = torch.roll(data["logprobs"], shifts=-1, dims=-1)
if not self.config.use_decoupled_loss:
# prox logp not available, use inferenced logp
data["prox_logp"] = old_logp
ref_logp = data.get("ref_logp", torch.zeros_like(old_logp))
ref_logp *= loss_mask
old_logp *= loss_mask
# Compute KL-regularized rewards.
attn_mask = data["attention_mask"]
seqlens = attn_mask.sum(-1).long()
seq_no_eos_mask = seqlens == attn_mask.shape[1]
rewards = -self.kl_ctl * (old_logp - ref_logp)
kl_rewards = rewards.clone()
# KL rewards at the next token after eos is zero.
rewards[batch_indices, seqlens - 1] = 0
indices = torch.clip(seqlens - 2, min=0)
if self.mask_no_eos_with_zero:
rewards[batch_indices, indices] += torch.where(
seq_no_eos_mask, 0, reward_score
)
else:
rewards[batch_indices, indices] += reward_score
# Compute GAE.
if "values" not in data:
values = torch.zeros_like(rewards)
else:
values = data["values"]
advantages_reversed = []
lastgaelam = 0
for t in reversed(range(max_seqlen - 1)):
nextvalues = values[:, t + 1]
if t == max_seqlen - 2:
nextvalues *= seq_no_eos_mask
delta = rewards[:, t] + self.discount * nextvalues - values[:, t]
lastgaelam = delta + self.discount * self.gae_lambda * lastgaelam
advantages_reversed.append(lastgaelam)
advantages_reversed.append(
torch.zeros(bs, dtype=torch.float32, device=values.device)
)
advantages = torch.stack(advantages_reversed[::-1], dim=1)
# Optionally perform advantage normalization.
if self.adv_norm:
if self.group_adv_norm:
adv_list = []
for i in range(0, bs, self.group_size):
s = slice(i, i + self.group_size)
adv = advantages[s]
m = loss_mask[s]
adv_list.append(masked_normalization(adv, m, all_reduce=False))
advantages = torch.cat(adv_list, 0)
else:
advantages = masked_normalization(advantages, loss_mask)
# Store data in the dict.
data["advantages"] = advantages
data["kl_rewards"] = kl_rewards
data["tot_rewards"] = rewards
data["loss_mask"] = loss_mask
# because we have rolled old_logp by -1
data["logprobs"] = old_logp
def ppo_update(self, data: TensorDict) -> List[Dict[str, float]]:
attn_mask = data["attention_mask"]
loss_mask = data["loss_mask"]
reward_score = data["rewards"]
seqlens = attn_mask.sum(-1)
all_stats = []
########## Logging code starts ##########
result_denominators = {
"correct_n_seqs": (reward_score > 0).bool(),
"incorrect_n_seqs": (reward_score <= 0).bool(),
}
global_denominators = dict(
n_seqs=torch.ones_like(reward_score, dtype=torch.bool),
n_tokens=torch.ones_like(loss_mask, dtype=torch.bool),
n_valid_tokens=loss_mask.bool(),
**result_denominators,
)
stats_tracker.denominator(**global_denominators)
stats_tracker.stat(
correct_seq_len=seqlens.float(), denominator="correct_n_seqs"
)
stats_tracker.stat(
incorrect_seq_len=seqlens.float(), denominator="incorrect_n_seqs"
)
stats = dict(
advantages=data["advantages"],
kl_rewards=data["kl_rewards"],
final_reward=data["tot_rewards"],
)
stats_tracker.stat(**stats, denominator="n_valid_tokens")
prompt_lens = data["attention_mask"].sum(-1) - data["loss_mask"].sum(-1)
seq_stats = dict(
no_eos_ratios=(seqlens == attn_mask.shape[-1]).float(),
task_reward=reward_score.float(),
prompt_len=prompt_lens.float(),
seq_len=seqlens.float(),
)
stats_tracker.stat(**seq_stats, denominator="n_seqs")
scalars = dict(
mask_no_eos_with_zero=self.config.mask_no_eos_with_zero,
eps_clip=self.config.eps_clip,
)
if self.config.c_clip is not None:
scalars["c_clip"] = self.config.c_clip
scalars["use_dual_clip"] = 1
else:
scalars["use_dual_clip"] = 0
if self.config.behav_imp_weight_cap is not None:
scalars["behav_imp_weight_cap"] = self.config.behav_imp_weight_cap
stats_tracker.scalar(**scalars)
global_stats = stats_tracker.export(reduce_group=self.engine.parallelism_group)
for k in global_denominators:
keys = list(global_stats.keys())
for k2 in keys:
if k2.endswith(k):
global_stats.pop(k2)
########## Logging code ends ##########
for key in ["rewards", "tot_rewards", "kl_rewards", "versions"]:
data.pop(key, None)
# NOTE: calling engine.train() is critical to enabling gradient checkpointing
self.engine.train()
mb_inputs = split_padded_tensor_dict_into_mb_list(
data,
mb_spec=MicroBatchSpec(n_mbs=self.config.ppo_n_minibatches),
)
for mb in mb_inputs.mbs:
train_stat = self.engine.train_batch(
mb,
loss_fn=functools.partial(
grpo_loss_fn,
temperature=self.temperature,
eps_clip=self.config.eps_clip,
c_clip=self.config.c_clip,
behav_imp_weight_cap=self.config.behav_imp_weight_cap,
),
loss_weight_fn=lambda x: x["loss_mask"].count_nonzero(),
)
stats_tracker.scalar(**train_stat)
all_stats.append(
stats_tracker.export(reduce_group=self.engine.parallelism_group)
)
all_stats[0].update(global_stats)
return all_stats
class FSDPPPOActor(FSDPEngine):
def __init__(self, config: PPOActorConfig):
super().__init__(config)
self.actor = PPOActor(config, self)
@torch.no_grad()
def compute_logp(self, *args, **kwargs) -> torch.Tensor | None:
return self.actor.compute_logp(*args, **kwargs)
@torch.no_grad()
def compute_advantages(self, *args, **kwargs) -> None:
self.actor.compute_advantages(*args, **kwargs)
def ppo_update(self, *args, **kwargs) -> List[Dict[str, float]]:
return self.actor.ppo_update(*args, **kwargs)
def grpo_loss_fn(
logits: torch.Tensor,
input_data: Dict,
temperature: float,
eps_clip: float,
c_clip: float | None,
behav_imp_weight_cap: float | None,
):
"""Loss function for actor step, all inputs should be splitted into
pipeline micro batches, returns loss and logging stats."""
input_ids = input_data["input_ids"]
old_logp = input_data["logprobs"]
advantages = input_data["advantages"]
loss_mask = input_data["loss_mask"].bool()
prox_logp = input_data["prox_logp"]
logprobs, entropy = gather_logprobs_entropy(
logits, torch.roll(input_ids, shifts=-1, dims=-1), temperature
)
entropy = entropy.detach()
loss, stat = ppo_actor_loss_fn(
logprobs=logprobs,
old_logprobs=old_logp,
advantages=advantages,
eps_clip=eps_clip,
loss_mask=loss_mask,
c_clip=c_clip,
proximal_logprobs=prox_logp,
behav_imp_weight_cap=behav_imp_weight_cap,
)
# Log training statistics
stats_tracker.denominator(
n_tokens=torch.ones(logits.shape[0], dtype=torch.bool, device=logits.device),
n_valid_tokens=loss_mask.bool(),
clipped_tokens=stat["clip_mask"],
dual_clipped_tokens=stat["dual_clip_mask"],
)
stats_tracker.stat(
importance_weight=stat["importance_weight"],
approx_kl=stat["approx_kl"],
new_logp=logprobs.detach(),
old_logp=old_logp,
entropy=entropy.float(),
actor_loss=stat["loss"],
clip_ratio=stat["clip_mask"].float(),
dual_clip_ratio=stat["dual_clip_mask"].float(),
denominator="n_valid_tokens",
)
if "behave_imp_weight" in stat:
stats_tracker.denominator(unclipped_behave_tokens=stat["behave_mask"])
stats_tracker.stat(
behave_imp_weight=stat["behave_imp_weight"],
behave_approx_kl=stat["behave_approx_kl"],
denominator="unclipped_behave_tokens",
)
vocab_min_logits = logits.detach().min(-1).values.float()
vocab_max_logits = logits.detach().max(-1).values.float()
stats_tracker.stat(
vocab_min_logits=vocab_min_logits,
vocab_max_logits=vocab_max_logits,
denominator="n_tokens",
)
clip_mask = stat["clip_mask"]
clipped_new_logp = torch.where(clip_mask, logprobs.detach(), 0.0)
clipped_old_logp = torch.where(clip_mask, old_logp, 0.0)
stats_tracker.stat(
clipped_new_logp=clipped_new_logp,
clipped_old_logp=clipped_old_logp,
denominator="clipped_tokens",
)
return loss
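ppo_actor_loss_fn itself comes from arealite.utils.functional and is not part of this diff; as a reference point only, a standard clipped surrogate (omitting the dual-clip and decoupled-loss branches) looks roughly like:
```python
import torch

def clipped_surrogate(logprobs, old_logprobs, advantages, loss_mask, eps_clip):
    # Reference sketch, not necessarily the exact arealite implementation.
    ratio = torch.exp(logprobs - old_logprobs)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip) * advantages
    per_token = -torch.minimum(unclipped, clipped)
    return (per_token * loss_mask).sum() / loss_mask.sum().clamp_min(1)
```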


@@ -20,7 +20,7 @@ class LMEngine:
return self.engine.train_batch(
input_=data,
loss_fn=compute_packed_sft_loss,
loss_weight_fn=lambda x: x["prompt_mask"].logical_not().count_nonzero(),
loss_weight_fn=lambda x: x["loss_mask"].count_nonzero(),
)
def evaluate_lm(self, data):
@@ -28,7 +28,7 @@ class LMEngine:
self.engine.eval_batch(
input_=data,
loss_fn=compute_packed_sft_loss,
loss_weight_fn=lambda x: x["prompt_mask"].logical_not().count_nonzero(),
loss_weight_fn=lambda x: x["loss_mask"].count_nonzero(),
)
@@ -49,26 +49,26 @@ def compute_packed_sft_loss(
) -> torch.Tensor:
packed_input_ids: torch.Tensor = input_["input_ids"]
cu_seqlens: torch.Tensor = input_["cu_seqlens"]
prompt_mask = input_["prompt_mask"].bool()
loss_mask = input_["loss_mask"].bool()
logprobs = gather_logprobs(logits, torch.roll(packed_input_ids, shifts=-1, dims=-1))
prompt_mask = torch.roll(prompt_mask, shifts=-1, dims=-1)
logprobs = torch.where(prompt_mask, 0, logprobs)
loss_mask = torch.roll(loss_mask, shifts=-1, dims=-1)
logprobs = torch.where(loss_mask, logprobs, 0)
loss = -logprobs.sum() / prompt_mask.logical_not().count_nonzero()
loss = -logprobs.sum() / loss_mask.count_nonzero()
with torch.no_grad():
seqlogp = torch.zeros(
cu_seqlens.shape[0] - 1, device=logits.device, dtype=torch.float64
)
for i in range(cu_seqlens.shape[0] - 1):
m = prompt_mask[cu_seqlens[i] - i : cu_seqlens[i + 1] - i - 1]
m = loss_mask[cu_seqlens[i] - i : cu_seqlens[i + 1] - i - 1]
logp = logprobs[cu_seqlens[i] - i : cu_seqlens[i + 1] - i - 1]
assert cu_seqlens[i + 1] - i - 1 <= logprobs.shape[0], (
cu_seqlens,
logprobs.shape,
)
seqlogp[i] = torch.where(m, 0.0, logp.detach()).sum() / (
seqlogp[i] = torch.where(m, logp.detach(), 0.0).sum() / (
m.numel() - m.count_nonzero()
)
@@ -78,8 +78,8 @@ def compute_packed_sft_loss(
cu_seqlens.shape[0] - 1, dtype=torch.bool, device=logprobs.device
),
n_tokens=torch.ones(logits.shape[0], dtype=torch.bool, device=logits.device),
n_valid_tokens=prompt_mask.logical_not(),
prompt_tokens=prompt_mask,
n_valid_tokens=loss_mask,
prompt_tokens=loss_mask.logical_not(),
)
stats_tracker.stat(ppl=(-seqlogp).exp().float(), denominator="n_seqs")
stats_tracker.stat(loss=-logprobs.detach(), denominator="n_valid_tokens")


@@ -1,23 +1,30 @@
import asyncio
import os
import random
import threading
import time
import traceback
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import ProcessPoolExecutor
from datetime import datetime
from queue import Empty, Full, Queue
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List
import aiohttp
import requests
import torch.distributed as dist
from tensordict import TensorDict
from torchdata.stateful_dataloader import StatefulDataLoader
from arealite.api.cli_args import InferenceEngineConfig
from arealite.api.engine_api import InferenceEngine
from arealite.api.io_struct import (
FinetuneSpec,
LLMRequest,
LLMResponse,
RolloutStat,
WeightUpdateMeta,
)
from arealite.utils.http import arequest_with_retry
from arealite.utils.padding import concat_padded_tensors
from realhf.base import logging, name_resolve, names, pkg_version
if TYPE_CHECKING:
@@ -30,7 +37,7 @@ if pkg_version.is_available("sglang"):
else:
SGLANG_TOKEN_OUTPUT_IDENTIFIER = "token_ids"
ROLLOUT_POLL_WAIT_TIME = 0.4
ROLLOUT_POLL_WAIT_TIME = 0.1
RID_CACHE_SIZE = 128
@@ -46,26 +53,57 @@ class RemoteSGLangEngine(InferenceEngine):
# Maintain the addresses for the recent 128 requests
self.rid_queue = []
self.addresses = config.server_addrs
self.server_idx = 0
self.addresses = os.getenv("AREAL_LLM_SERVER_ADDRS").split(",")
if not self.addresses:
raise RuntimeError("No configured SGLang servers.")
logger.info("Waiting for server ready...")
for addr in self.addresses:
self._wait_for_server(addr)
logger.info("Servers are all ready!")
qsize = config.queue_size or config.max_concurrent_rollouts * 10
self.server_idx = random.randint(0, len(self.addresses) - 1)
qsize = config.queue_size or config.max_concurrent_rollouts * 16
self.input_queue = Queue(maxsize=qsize)
self.output_queue = Queue(maxsize=qsize)
self.result_cache = []
self.exiting = threading.Event()
self.paused = threading.Event()
self.lock = threading.Lock()
self.rollout_stat = RolloutStat()
self._version = 0
def initialize(self, addr: str | None, ft_spec: Optional[Dict[str, Any]] = None):
def _wait_for_server(self, address):
base_url = f"http://{address}"
tik = time.time()
while time.time() - tik < self.config.setup_timeout:
if self.check_health(base_url):
return
time.sleep(1)
raise RuntimeError("server launch failed")
def check_health(self, base_url):
# Check server endpoint
try:
response = requests.get(
f"{base_url}/metrics",
timeout=30,
)
return response.status_code == 200
except requests.exceptions.RequestException:
return False
def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None = None):
self.rollout_tasks: Dict[str, asyncio.Task] = {}
self.executor = ProcessPoolExecutor(max_workers=1)
self.rollout_thread = threading.Thread(target=self._rollout_thread)
self.rollout_thread.start()
def destroy(self):
self.executor.shutdown()
self.exiting.set()
self.rollout_thread.join()
@@ -85,79 +123,45 @@ class RemoteSGLangEngine(InferenceEngine):
traceback.print_exc()
async def _rollout_thread_async(self):
data = None
rollout_tasks: Dict[str, asyncio.Task] = {}
pending_data = []
rollout_tasks = self.rollout_tasks
rid = 0
try:
while not self.exiting.is_set():
# Load next data from controller
if data is None:
while True:
try:
data, workflow = self.input_queue.get_nowait()
logger.info(f"Get data from puller: {data}")
logger.debug(f"Get data from puller: {data}")
pending_data.append(data)
except Empty:
logger.debug(f"No data from puller stream.")
break
# Check capacity
if dist.is_initialized():
world_size = dist.get_world_size()
else:
world_size = 1
cannot_rollout_reason = []
capacity = max(1, self.config.max_concurrent_rollouts // world_size)
can_rollout = len(rollout_tasks) < capacity
if not can_rollout:
cannot_rollout_reason.append(
f"Exceeding capacity: # running tasks {len(rollout_tasks)} >= capacity {capacity}"
)
# Staleness control
version = self.get_version()
ofp = self.config.max_head_offpolicyness
with self.lock:
sample_cnt = self.rollout_stat.accepted + self.rollout_stat.running
expected_version = sample_cnt // self.config.consumer_batch_size
not_staled = expected_version <= ofp + version
can_rollout &= not_staled
if not not_staled:
cannot_rollout_reason.append(
f"Staled: expected version ({expected_version}) = "
f"global sample cnt ({sample_cnt}) // batch size ({self.config.consumer_batch_size}), "
f"current latest version {version}, "
f"offpolicyness {self.config.max_head_offpolicyness}."
)
if not can_rollout:
logger.debug(
f"Cannot submit new rollouts. "
+ "\n".join(cannot_rollout_reason)
)
capacity = self.get_capacity()
# Create new rollout task
if can_rollout and data is not None:
while capacity > 0 and pending_data and not self.paused.is_set():
task = asyncio.create_task(
workflow.arun_episode(self, data), name=str(rid)
workflow.arun_episode(self, pending_data.pop(0)), name=str(rid)
)
rollout_tasks[str(rid)] = task
with self.lock:
rollout_tasks[str(rid)] = task
self.rollout_stat.submitted += 1
self.rollout_stat.running += 1
logger.info(
f"Submit rollout rid {rid}. "
f"Submit: {self.rollout_stat.submitted}, "
f"running: {self.rollout_stat.running}, "
f"accepted: {self.rollout_stat.accepted}."
)
if self.config.enable_rollout_tracing:
logger.info(
f"Submit rollout rid {rid}. "
f"Submit: {self.rollout_stat.submitted}, "
f"running: {self.rollout_stat.running}, "
f"accepted: {self.rollout_stat.accepted}."
)
capacity -= 1
rid += 1
data = None
# Wait for rollout completion
tasks = list(rollout_tasks.values())
with self.lock:
tasks = list(rollout_tasks.values())
done = []
if tasks:
done, _ = await asyncio.wait(
@@ -165,16 +169,19 @@ class RemoteSGLangEngine(InferenceEngine):
timeout=ROLLOUT_POLL_WAIT_TIME,
return_when=asyncio.FIRST_COMPLETED,
)
if not done:
await asyncio.sleep(1)
else:
await asyncio.sleep(ROLLOUT_POLL_WAIT_TIME)
await asyncio.sleep(1)
# Collect done results
for task in done:
traj = await task
traj: TensorDict
task_rid = task.get_name()
rollout_tasks.pop(task_rid)
self.rollout_stat.accepted += 1
with self.lock:
rollout_tasks.pop(task_rid)
self.rollout_stat.accepted += 1
try:
self.output_queue.put_nowait(traj)
@@ -185,21 +192,25 @@ class RemoteSGLangEngine(InferenceEngine):
with self.lock:
self.rollout_stat.running -= 1
logger.info(
f"Finish rollout {task_rid}. "
f"Submit: {self.rollout_stat.submitted}, "
f"running: {self.rollout_stat.running}, "
f"accepted: {self.rollout_stat.accepted}."
)
if self.config.enable_rollout_tracing:
logger.info(
f"Finish rollout {task_rid}. "
f"Submit: {self.rollout_stat.submitted}, "
f"running: {self.rollout_stat.running}, "
f"accepted: {self.rollout_stat.accepted}."
)
except Exception:
traceback.print_exc()
finally:
# Cancel remaining tasks
for task in rollout_tasks.values():
if not task.done():
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
with self.lock:
for task in rollout_tasks.values():
if not task.done():
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
def choose_server(self) -> str:
if self.config.schedule_policy == "round_robin":
@@ -208,65 +219,6 @@ class RemoteSGLangEngine(InferenceEngine):
return server
raise NotImplementedError("Only round-robin scheduling is implemented.")
async def arequest_with_retry(
self,
endpoint: str,
payload: Optional[Dict[str, Any]] = None,
method: str = "POST",
max_retries: Optional[int] = None,
timeout: Optional[float] = None,
retry_delay: float = 1.0,
target_addr: Optional[str] = None,
) -> aiohttp.ClientResponse:
timeout = timeout or self.config.request_timeout
last_exception = None
max_retries = max_retries or self.config.request_retries
# Try with retries
for _ in range(max_retries):
if target_addr:
addr = target_addr
else:
addr = self.choose_server()
base_url = f"http://{addr}"
url = f"{base_url}{endpoint}"
for attempt in range(max_retries):
try:
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(
total=timeout,
sock_connect=30,
sock_read=timeout,
)
) as session:
if method.upper() == "GET":
response = await session.get(url)
elif method.upper() == "POST":
response = await session.post(url, json=payload)
elif method.upper() == "PUT":
response = await session.put(url, json=payload)
elif method.upper() == "DELETE":
response = await session.delete(url)
else:
raise ValueError(f"Unsupported HTTP method: {method}")
response.raise_for_status()
return response
except (
aiohttp.ClientError,
aiohttp.ClientResponseError,
asyncio.TimeoutError,
) as e:
last_exception = e
if attempt < max_retries - 1:
await asyncio.sleep(retry_delay)
continue
raise RuntimeError(
f"Failed after {max_retries} retries each. " f"Last error: {last_exception}"
)
async def agenerate(self, req: LLMRequest) -> LLMResponse:
"""Async version of generate using aiohttp."""
# Prepare request payload
@@ -288,15 +240,11 @@ class RemoteSGLangEngine(InferenceEngine):
# NOTE: rid should NOT be passed in payload
payload = {
"text": req.text,
"input_ids": req.input_ids.copy(),
"sampling_params": sample_params,
"return_logprob": True,
"stream": False,
}
if req.text:
payload["text"] = req.text
else:
payload["input_ids"] = req.input_ids
# Make request
start_time = time.perf_counter()
@@ -324,18 +272,16 @@ class RemoteSGLangEngine(InferenceEngine):
and len(accumulated_output_tokens) < gconfig.max_new_tokens
):
# loop until the generation is complete
response = await self.arequest_with_retry(
result = await arequest_with_retry(
addr=self.choose_server(),
endpoint="/generate",
payload=payload,
method="POST",
max_retries=3,
timeout=self.config.request_timeout,
target_addr=server_addr,
)
result = await response.json()
# Parse response
completions += result["text"]
meta_info = result["meta_info"]
output_tokens = [x[1] for x in meta_info["output_token_logprobs"]]
output_logprobs = [x[0] for x in meta_info["output_token_logprobs"]]
@@ -350,12 +296,15 @@ class RemoteSGLangEngine(InferenceEngine):
finish_reason = meta_info["finish_reason"]
stop_reason = finish_reason["type"]
payload["text"] += result["text"]
payload["input_ids"] += result[SGLANG_TOKEN_OUTPUT_IDENTIFIER]
sample_params["max_new_tokens"] = min(
sample_params["max_new_tokens"],
gconfig.max_new_tokens - len(output_tokens),
)
latency = time.perf_counter() - start_time
return LLMResponse(
completions=completions,
input_tokens=req.input_ids,
output_tokens=accumulated_output_tokens,
output_logprobs=accumulated_output_logprobs,
@@ -365,56 +314,47 @@ class RemoteSGLangEngine(InferenceEngine):
ttft=latency, # Simplified for non-streaming
)
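A hedged sketch of calling agenerate directly from async code; note that after this diff LLMRequest carries token ids only, the text field having been removed (the max_new_tokens kwarg follows its usage above):
```python
from arealite.api.cli_args import GenerationHyperparameters
from arealite.api.io_struct import LLMRequest

async def generate_once(engine, prompt_ids):
    req = LLMRequest(
        input_ids=prompt_ids,
        gconfig=GenerationHyperparameters(max_new_tokens=128),
    )
    resp = await engine.agenerate(req)
    return resp.output_tokens, resp.output_logprobs
```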
def update_weights(self, meta):
executor = ThreadPoolExecutor(max_workers=1)
return executor.submit(self._update_weights, meta)
def _update_weights(self, meta: WeightUpdateMeta):
def update_weights(self, meta: WeightUpdateMeta):
if meta.type == "disk":
# Update weights from disk
# Wait for model checkpoints of meta.version
update_name = names.update_weights_from_disk(
self.config.experiment_name, self.config.trial_name, meta.model_version
# Use ProcessPool to bypass python GIL for running async coroutines
fut = self.executor.submit(
update_weights_from_disk,
self.config.experiment_name,
self.config.trial_name,
meta.model_version,
self.addresses,
meta.path,
self.config.request_retries,
self.config.request_timeout,
)
save_timestamp = int(name_resolve.wait(update_name, timeout=120))
load_timestamp = time.time_ns()
logger.info(
f"Begin update weights from {meta.path}, responded in {(load_timestamp - save_timestamp)/1e6:.2f} ms"
)
try:
jobs = [
self.aupdate_weights_from_disk(addr, meta.path)
for addr in self.addresses
]
loop = asyncio.new_event_loop()
# asyncio event loop should be manually set when running asyncio stuff in another thread
asyncio.set_event_loop(loop)
loop.run_until_complete(asyncio.gather(*jobs))
finally:
loop.close()
logger.info(
f"Loading weights done in {(time.time_ns() - load_timestamp)/1e6:.2f} ms"
)
self.set_version(meta.model_version)
def callback(fut):
self.set_version(meta.model_version)
fut.add_done_callback(callback)
return fut
else:
raise NotImplementedError(f"Unsupported weight update type: {meta.type}")
async def aupdate_weights_from_disk(self, addr, path: str):
response = await self.arequest_with_retry(
endpoint="/update_weights_from_disk",
payload=dict(model_path=str(path), allow_interrupt=True),
method="POST",
max_retries=3,
timeout=self.config.request_timeout,
target_addr=addr,
def get_capacity(self):
if dist.is_initialized():
world_size = dist.get_world_size()
else:
world_size = 1
max_concurrent_rollouts = max(
1, self.config.max_concurrent_rollouts // world_size
)
res = await response.json()
assert res["success"]
if "num_paused_requests" in res:
logger.info(
f"{res['num_paused_requests']} requests are interrupted "
f"during updating weights for server {addr}"
)
capacity = max_concurrent_rollouts - len(self.rollout_tasks)
# Staleness control
version = self.get_version()
ofp = self.config.max_head_offpolicyness
with self.lock:
sample_cnt = self.rollout_stat.accepted + self.rollout_stat.running
consumer_bs = max(1, self.config.consumer_batch_size // world_size)
capacity = min(capacity, (ofp + version + 1) * consumer_bs - sample_cnt)
return capacity
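A worked example of the staleness bound in get_capacity, with invented numbers:
```python
# Suppose max_head_offpolicyness = 1, current weight version = 3,
# per-rank consumer batch size = 8, and 30 samples already accepted/running:
ofp, version, consumer_bs, sample_cnt = 1, 3, 8, 30
staleness_cap = (ofp + version + 1) * consumer_bs - sample_cnt  # 40 - 30 = 10
# Capacity is the minimum of the concurrency headroom and this cap, so at
# most 10 new rollouts may start before the next weight version arrives.
```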
def submit(self, data: Dict[str, Any], workflow: "RolloutWorkflow") -> None:
try:
@@ -422,9 +362,15 @@ class RemoteSGLangEngine(InferenceEngine):
except Full:
raise RuntimeError("Input queue full. Please increase queue_size.")
def wait(self, count: int, timeout: float, should_accept: Callable) -> TensorDict:
def wait(
self,
count: int,
timeout: float | None = None,
should_accept: Callable | None = None,
) -> TensorDict:
tik = time.perf_counter()
accepted = len(self.result_cache)
timeout = timeout or float(7 * 24 * 3600)
while (
accepted < count
and not self.exiting.is_set()
@@ -432,14 +378,14 @@ class RemoteSGLangEngine(InferenceEngine):
):
try:
result = self.output_queue.get(timeout=ROLLOUT_POLL_WAIT_TIME)
if should_accept(result):
if should_accept is None or should_accept(result):
self.result_cache.append(result)
accepted += 1
else:
with self.lock:
self.rollout_stat.accepted -= 1
except Empty:
time.sleep(ROLLOUT_POLL_WAIT_TIME)
pass
if self.exiting.is_set():
raise RuntimeError("Rollout engine is exiting, cannot wait for results.")
if accepted < count:
@@ -450,16 +396,94 @@ class RemoteSGLangEngine(InferenceEngine):
self.result_cache[:count],
self.result_cache[count:],
)
return TensorDict.cat(results, dim=0)
return concat_padded_tensors(results)
def rollout(
def rollout_batch(
self, data: List[Dict[str, Any]], workflow: "RolloutWorkflow"
) -> TensorDict:
"""Submit a batch of requests to the inference engine and wait for the results."""
for item in data:
self.submit(item, workflow)
return self.wait(
count=len(data),
timeout=self.config.request_timeout,
should_accept=lambda x: True,
return self.wait(count=len(data))
def prepare_batch(
self,
data_generator: Iterator,
dataloader: StatefulDataLoader,
workflow: "RolloutWorkflow",
):
assert dataloader.batch_size is not None
while True:
if self.get_capacity() + dataloader.batch_size > 0:
try:
data = next(data_generator)
except StopIteration:
data_generator = iter(dataloader)
data = next(data_generator)
for item in data:
self.submit(item, workflow=workflow)
try:
return self.wait(dataloader.batch_size, timeout=1)
except TimeoutError:
pass
def pause(self):
self.paused.set()
def resume(self):
self.paused.clear()
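Sketch of the driver loop implied by prepare_batch (construction of the engine, dataloader, and workflow is elided; names are placeholders):
```python
def train_loop(engine, dataloader, workflow, max_steps):
    data_generator = iter(dataloader)
    for _ in range(max_steps):
        # Submits prompts whenever capacity allows, returns one full batch.
        batch = engine.prepare_batch(data_generator, dataloader, workflow=workflow)
        ...  # recompute logprobs, compute advantages, run ppo_update
```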
async def aupdate_weights_from_disk(
addr, path: str, request_retries: int, request_timeout: float
):
res = await arequest_with_retry(
addr=addr,
endpoint="/update_weights_from_disk",
payload=dict(model_path=str(path), allow_interrupt=True),
method="POST",
max_retries=request_retries,
timeout=request_timeout,
)
assert res["success"]
if "num_paused_requests" in res:
logger.info(
f"{res['num_paused_requests']} requests are interrupted "
f"during updating weights for server {addr}"
)
def update_weights_from_disk(
experiment_name,
trial_name,
model_version,
addresses,
path,
request_retries,
request_timeout,
):
async def _fn():
# Wait for model checkpoints of meta.version
update_name = names.update_weights_from_disk(
experiment_name, trial_name, model_version
)
save_timestamp = float(name_resolve.wait(update_name, timeout=120))
load_timestamp = datetime.now().timestamp()
logger.info(
f"Begin update weights from {path}, responded in {(load_timestamp - save_timestamp):.2f}s"
)
jobs = [
aupdate_weights_from_disk(
addr,
path=path,
request_retries=request_retries,
request_timeout=request_timeout,
)
for addr in addresses
]
await asyncio.gather(*jobs)
logger.info(
f"Loading weights done in {(datetime.now().timestamp() - load_timestamp):.2f}s"
)
return asyncio.run(_fn())
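On the trainer side, the disk-based handoff implied by update_weights would look roughly as follows (WeightUpdateMeta field names are taken from their usage above; the exact constructor is assumed):
```python
from arealite.api.io_struct import WeightUpdateMeta

def push_new_weights(rollout_engine, save_path: str, step: int):
    meta = WeightUpdateMeta(type="disk", path=save_path, model_version=step)
    fut = rollout_engine.update_weights(meta)  # Future from the ProcessPoolExecutor
    fut.result()  # block until all servers reload; the version bumps via callback
```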

arealite/launcher/local.py (new file, 307 lines)

@@ -0,0 +1,307 @@
import getpass
import os
import re
import signal as signal_module
import subprocess
import sys
import time
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union
import psutil
from arealite.api.cli_args import SGLangConfig, parse_cli_args, to_structured_cfg
from arealite.api.io_struct import AllocationMode, AllocationType
from arealite.utils.network import find_free_ports, gethostip
from realhf.base import gpu_utils, logging, name_resolve, names
from realhf.scheduler.client import JobException, JobInfo, JobState
logger = logging.getLogger("Local Scheduler")
JOB_STATE_TO_PROCESS_STATUS = {
JobState.NOT_FOUND: [],
JobState.PENDING: [psutil.STATUS_PARKED],
JobState.RUNNING: [
psutil.STATUS_RUNNING,
psutil.STATUS_SLEEPING,
psutil.STATUS_DISK_SLEEP,
psutil.STATUS_TRACING_STOP,
psutil.STATUS_WAKING,
psutil.STATUS_WAITING,
psutil.STATUS_LOCKED,
psutil.STATUS_IDLE,
],
JobState.COMPLETED: [
psutil.STATUS_DEAD,
psutil.STATUS_STOPPED,
psutil.STATUS_ZOMBIE,
],
JobState.FAILED: [],
JobState.CANCELLED: [],
}
PROCESS_STATUS_TO_JOB_STATE = {}
for job_state, process_statuses in JOB_STATE_TO_PROCESS_STATUS.items():
for process_status in process_statuses:
PROCESS_STATUS_TO_JOB_STATE[process_status] = job_state
def terminate_process_and_children(pid: int, signal: Optional[Union[str, int]] = None):
if signal is None:
signal = signal_module.SIGKILL
if isinstance(signal, str):
signal = getattr(signal_module, signal)
try:
parent = psutil.Process(pid)
children = parent.children(recursive=True)
for child in children:
terminate_process_and_children(child.pid)
parent.send_signal(signal)
except psutil.NoSuchProcess:
pass
class LocalLauncher:
def __init__(self, experiment_name: str, trial_name: str, fileroot: str):
self.experiment_name = experiment_name
self.trial_name = trial_name
self.fileroot = fileroot
self._jobs: Dict[str, subprocess.Popen] = {}
self._job_counter: Dict[str, int] = defaultdict(int)
self._job_states = {}
self._gpu_counter = 0
self._cuda_devices: List[str] = os.environ.get(
"CUDA_VISIBLE_DEVICES", ",".join(map(str, range(gpu_utils.gpu_count())))
).split(",")
if len(self._cuda_devices) < 1:
raise RuntimeError(
f"Local mode can only run when there is at least one GPU. "
f"CUDA_VISIBLE_DEVICES is currently set to {os.environ['CUDA_VISIBLE_DEVICES']}."
)
@property
def run_name(self):
return f"{self.experiment_name}_{self.trial_name}"
def log_path_of(self, job_name: str) -> str:
log_path = f"{self.fileroot}/logs/{getpass.getuser()}/{self.experiment_name}/{self.trial_name}"
os.makedirs(log_path, exist_ok=True)
return os.path.join(log_path, f"{job_name}.log")
def __del__(self):
self.wait()
def submit_array(
self,
job_name: str,
cmd: str | List[str],
count: int = 1,
gpu: int = 0,
env_vars: Optional[Dict] = None,
):
if env_vars is None:
env_vars = {}
if not isinstance(cmd, list):
cmd = [cmd] * count
offset = self._job_counter[job_name]
for i in range(count):
if gpu > 0:
# Allocate GPUs in a round-robin manner
visible_devices = []
for _ in range(gpu):
available_device_id = self._gpu_counter % len(self._cuda_devices)
self._gpu_counter += 1
visible_devices.append(available_device_id)
env_vars["CUDA_VISIBLE_DEVICES"] = ",".join(
str(self._cuda_devices[j]) for j in visible_devices
)
c = (
" ".join(str(k) + "=" + str(v) for k, v in env_vars.items())
+ " stdbuf -oL "
+ cmd[i]
)
c = f"{c} | tee -a {self.log_path_of(job_name)}"
logger.info("Starting local process with command: %s", c)
process = subprocess.Popen(c, shell=isinstance(c, str))
self._jobs[f"{job_name}/{offset + i}"] = process
self._job_counter[job_name] += 1
def submit(
self,
job_name: str,
cmd: str | List[str],
gpu: int = 0,
env_vars: Optional[Dict] = None,
):
self.submit_array(job_name=job_name, cmd=cmd, gpu=gpu, env_vars=env_vars)
def stop(self, job_name, signal=None):
assert any(k.startswith(job_name) for k in self._jobs)
keys = [k for k, p in self._jobs.items() if k.startswith(job_name)]
procs = [p for k, p in self._jobs.items() if k.startswith(job_name)]
logger.info(
f"Stopping local process with signal {signal if signal else 'SIGKILL'}, "
f"pid: {[p.pid for p in procs]}"
)
for p in procs:
terminate_process_and_children(p.pid, signal=signal)
for p in procs:
p.wait()
for k, p in zip(keys, procs):
self._jobs.pop(k)
del p
def stop_all(self, signal=None):
# Forward the signal to every tracked job name.
for name in self._job_counter:
self.stop(name, signal=signal)
def find(self, job_name):
if job_name in self._jobs:
return JobInfo(name=job_name, state=JobState.RUNNING, host="localhost")
else:
return JobInfo(name=job_name, state=JobState.NOT_FOUND)
def find_all(self, job_name_regex=".*"):
rs = []
for name in self._jobs:
if re.fullmatch(job_name_regex, name):
rs.append(self.find(name))
return rs
def wait(
self,
timeout=None,
check_status: Tuple[JobState, ...] = (
JobState.CANCELLED,
JobState.FAILED,
JobState.NOT_FOUND,
),
remove_status: Tuple[JobState, ...] = (JobState.COMPLETED,),
update=False,
):
deadline = None if timeout is None else time.time() + timeout
logger.info(
"Waiting for %d local running processes, pids: %s",
len(self._jobs),
" ".join(str(job.pid) for job in self._jobs.values()),
)
left = set(self._jobs.keys())
num_jobs_left = len(left)
while len(left) > 0:
to_remove = []
if len(left) < num_jobs_left:
num_jobs_left = len(left)
logger.info(f"Waiting for {num_jobs_left} jobs.")
if deadline is not None and time.time() > deadline:
raise TimeoutError(
f"Timeout waiting for {self.run_name}: {', '.join(sorted(left))}"
)
# update job states
for job_name in list(left):
job = self._jobs[job_name]
pid = job.pid
process = psutil.Process(pid)
self._job_states[job_name] = PROCESS_STATUS_TO_JOB_STATE.get(
process.status(), JobState.NOT_FOUND
)
for job_name in list(left):
state = self._job_states[job_name]
if state in check_status:
raise JobException(
run_name=self.run_name,
worker_type=job_name.split("/")[0],
host="local",
reason=state,
)
if state in remove_status:
logger.info(f"Job {job_name} is {state}.(Removed)")
left.remove(job_name)
to_remove.append(job_name)
if update:
for k in to_remove:
self._jobs.pop(k)
worker_type = k.split("/")[0]
assert worker_type in self._job_counter
self._job_counter[worker_type] -= 1
if self._job_counter[worker_type] <= 0:
self._job_counter.pop(worker_type)
time.sleep(2)
def main_local():
cfg, _ = parse_cli_args(sys.argv[2:])
name_resolve.reconfigure(cfg.cluster.name_resolve)
name_resolve.clear_subtree(
names.trial_root(experiment_name=cfg.experiment_name, trial_name=cfg.trial_name)
)
alloc_mode = AllocationMode.from_str(cfg.allocation_mode)
launcher = LocalLauncher(cfg.experiment_name, cfg.trial_name, cfg.cluster.fileroot)
server_cmd = []
server_addrs = []
if alloc_mode.type_ == AllocationType.DECOUPLED_SGLANG:
base_seed = cfg.sglang.random_seed
cfg.sglang = to_structured_cfg(cfg.sglang, SGLangConfig)
ports = find_free_ports(alloc_mode.gen_dp_size * 2, port_range=(10000, 50000))
host_ip = gethostip()
host = "localhost" if not cfg.sglang.enable_metrics else host_ip
for i in range(alloc_mode.gen_dp_size):
cfg.sglang.random_seed = base_seed + i
cmd = SGLangConfig.build_cmd(
cfg.sglang,
host=host,
tp_size=alloc_mode.gen_tp_size,
base_gpu_id=0,
port=ports[i * 2],
dist_init_addr=f"localhost:{ports[i*2+1]}",
)
server_cmd.append(cmd)
server_addrs.append(f"{host}:{ports[i * 2]}")
else:
raise NotImplementedError()
# Launch inference servers.
launcher.submit_array(
job_name="llm_server",
cmd=server_cmd,
count=alloc_mode.gen_dp_size,
gpu=alloc_mode.gen_pp_size * alloc_mode.gen_tp_size,
)
logger.info(
f"LLM inference server launched at: AREAL_LLM_SERVER_ADDRS={','.join(server_addrs)}"
)
# Launch trainer entrypoint
if not cfg.server_only:
launcher.submit(
job_name="trainer",
cmd=f"torchrun --nnodes 1 --nproc-per-node {alloc_mode.train_world_size} --standalone {' '.join(sys.argv[1:])}",
gpu=alloc_mode.train_world_size,
env_vars=dict(AREAL_LLM_SERVER_ADDRS=",".join(server_addrs)),
)
try:
launcher.wait(
check_status=(
JobState.CANCELLED,
JobState.FAILED,
JobState.NOT_FOUND,
JobState.COMPLETED,
),
remove_status=(),
)
except (KeyboardInterrupt, JobException, TimeoutError) as e:
launcher.stop_all("SIGTERM")
raise e
if __name__ == "__main__":
main_local()
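As a quick usage sketch of the class above (job names and commands are illustrative, not from this PR):
launcher = LocalLauncher("demo-expr", "trial0", fileroot="/tmp/arealite/experiments")
launcher.submit("trainer", cmd="python -c 'print(1)'", gpu=0)
print(launcher.find("trainer/0"))   # JobInfo with state=RUNNING while the process is alive
launcher.stop_all("SIGTERM")        # forwards the signal to each job's process tree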

View File

@ -5,7 +5,6 @@ import time
import uuid
import pytest
import requests
import torch
from tensordict import TensorDict
@ -15,52 +14,41 @@ from arealite.api.cli_args import (
SGLangConfig,
)
from arealite.api.io_struct import LLMRequest, LLMResponse, WeightUpdateMeta
from arealite.utils import network
from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import network
EXPR_NAME = "test_sglang_engine"
TRIAL_NAME = "trial_0"
MODEL_PATH = "/storage/testing/models/Qwen__Qwen3-1.7B/"
if not os.path.exists(MODEL_PATH):
MODEL_PATH = "Qwen/Qwen2-0.5B"
PORT = 13887
DIST_PORT = 15887
PORT, DIST_PORT = network.find_free_ports(2)
HOST = network.gethostip()
# set a large timeout since we may need to download the model from hub
RUN_SERVER_TIMEOUT = 180
def check_server_health(base_url):
# Check server endpoint
try:
response = requests.get(
f"{base_url}/metrics",
timeout=30,
)
return response.status_code == 200
except requests.exceptions.RequestException:
return False
@pytest.fixture(scope="module")
def sglang_server():
from realhf.base import seeding
seeding.set_random_seed(1, EXPR_NAME)
cmd = SGLangConfig.build_cmd(
sglang_config=SGLangConfig(mem_fraction_static=0.3),
model_path=MODEL_PATH,
sglang_config=SGLangConfig(
skip_tokenizer_init=True,
model_path=MODEL_PATH,
mem_fraction_static=0.3,
),
host=HOST,
port=PORT,
tp_size=1,
base_gpu_id=0,
dist_init_addr=f"{HOST}:{DIST_PORT}",
served_model_name=MODEL_PATH,
skip_tokenizer_init=False,
)
# Launch process
full_command = f"{cmd} --port {PORT}"
full_command = full_command.replace("\\\n", " ").replace("\\", " ")
cmd = cmd.replace("\\\n", " ").replace("\\", " ")
process = subprocess.Popen(
full_command.split(),
cmd.split(),
text=True,
stdout=sys.stdout,
stderr=sys.stdout,
@ -82,11 +70,12 @@ async def test_remote_sglang_generate(sglang_server):
from arealite.engine.sglang_remote import RemoteSGLangEngine
config = InferenceEngineConfig(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME)
config.server_addrs = [f"{HOST}:{PORT}"]
tokenizer = load_hf_tokenizer(MODEL_PATH)
os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}"
engine = RemoteSGLangEngine(config)
req = LLMRequest(
rid=str(uuid.uuid4()),
text="hello! how are you today",
input_ids=tokenizer.encode("hello! how are you today"),
gconfig=GenerationHyperparameters(max_new_tokens=16),
)
resp = await engine.agenerate(req)
@ -97,7 +86,6 @@ async def test_remote_sglang_generate(sglang_server):
== len(resp.output_tokens)
== len(resp.output_versions)
)
assert isinstance(resp.completions, str)
@pytest.mark.parametrize("n_samples", [1, 2, 4])
@ -111,7 +99,7 @@ def test_remote_sglang_rollout(sglang_server, n_samples):
max_concurrent_rollouts=2,
consumer_batch_size=2,
)
config.server_addrs = [f"{HOST}:{PORT}"]
os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}"
engine = RemoteSGLangEngine(config)
engine.initialize(None, None)
@ -124,12 +112,13 @@ def test_remote_sglang_rollout(sglang_server, n_samples):
reward_fn=lambda **kwargs: 1.0, # Dummy reward function
gconfig=gconfig,
tokenizer=tokenizer,
enable_thinking=False,
)
data = {
"messages": [{"role": "user", "content": "Hello, how are you?"}],
}
result = engine.rollout([data] * 2, workflow=workflow)
result = engine.rollout_batch([data] * 2, workflow=workflow)
assert isinstance(result, TensorDict)
bs = result.batch_size
assert bs == torch.Size([2 * n_samples])
@ -149,7 +138,7 @@ def test_remote_sglang_staleness_control(sglang_server, bs, ofp, n_samples):
consumer_batch_size=bs,
max_head_offpolicyness=ofp,
)
config.server_addrs = [f"{HOST}:{PORT}"]
os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}"
engine = RemoteSGLangEngine(config)
engine.initialize(None, None)
@ -162,6 +151,7 @@ def test_remote_sglang_staleness_control(sglang_server, bs, ofp, n_samples):
reward_fn=lambda **kwargs: 1.0, # Dummy reward function
gconfig=gconfig,
tokenizer=tokenizer,
enable_thinking=False,
)
data = {
"messages": [{"role": "user", "content": "Hello, how are you?"}],
@ -170,7 +160,7 @@ def test_remote_sglang_staleness_control(sglang_server, bs, ofp, n_samples):
engine.submit(data, workflow=workflow)
# wait for some time
time.sleep(15)
time.sleep(5)
assert engine.output_queue.qsize() == min(bs * 2, bs * (ofp + 1))
# Update model version
@ -181,7 +171,7 @@ def test_remote_sglang_staleness_control(sglang_server, bs, ofp, n_samples):
for _ in range(bs * 2):
engine.submit(data, workflow=workflow)
# wait for some time
time.sleep(15)
time.sleep(5)
assert engine.output_queue.qsize() == min(bs * 4, bs * (ofp + 2))
# exit
@ -222,8 +212,9 @@ def test_disk_update_weights_from_fsdp_engine(tmp_path_factory, sglang_server):
from arealite.engine.sglang_remote import RemoteSGLangEngine
config = InferenceEngineConfig(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME)
config.server_addrs = [f"{HOST}:{PORT}"]
os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}"
inf_engine = RemoteSGLangEngine(config)
inf_engine.initialize(None, None)
# test update weights
path = tmp_path_factory.mktemp("upload_weights_from_disk")
update_weight_meta = WeightUpdateMeta(
@ -233,3 +224,4 @@ def test_disk_update_weights_from_fsdp_engine(tmp_path_factory, sglang_server):
engine.upload_weights(update_weight_meta)
future.result()
assert inf_engine.get_version() == 100
inf_engine.destroy()

View File

@ -8,7 +8,7 @@ from arealite.utils.data import (
pad_and_stack_tensors_along_first_dim,
pad_sequences_to_tensors,
reorder_list,
split_packed_tensor_dict_into_mb_list,
split_padded_tensor_dict_into_mb_list,
unpack_sequence,
)
@ -28,7 +28,7 @@ def mock_padded_data():
ans_len = int(ans_len)
seq = dict(
input_ids=torch.randint(0, VOCAB_SIZE, size=(prompt_len + ans_len,)),
prompt_mask=torch.tensor([1] * prompt_len + [0] * ans_len),
loss_mask=torch.tensor([0] * prompt_len + [1] * ans_len),
logprobs=torch.randn(prompt_len + ans_len),
position_ids=torch.arange(prompt_len + ans_len),
)
@ -45,7 +45,8 @@ def test_micro_batch_split(mock_padded_data, n_mbs, max_tokens_per_mb):
packed_data = pack_tensor_dict(mock_padded_data)
original_lens = packed_data["cu_seqlens"][1:] - packed_data["cu_seqlens"][:-1]
assert torch.allclose(original_lens, mock_padded_data["attention_mask"].sum(1))
split_result = split_packed_tensor_dict_into_mb_list(packed_data, mb_spec)
split_result = split_padded_tensor_dict_into_mb_list(mock_padded_data, mb_spec)
split_result.mbs = [pack_tensor_dict(mb) for mb in split_result.mbs]
reordered_lens = [original_lens[i] for i in split_result.forward_indices]
# assert microbatch split result does not violate requirements

View File

@ -110,11 +110,11 @@ def pad_input(hidden_states, indices, batch, seqlen):
def concat_padded_tensors(
tensor_dicts: List[Dict[str, torch.Tensor]], pad_value: float = 0.0
) -> Dict[str, torch.Tensor]:
tensor_dicts: List[TensorDict], pad_value: float = 0.0
) -> TensorDict:
"""Concatenate and pad tensors from multiple padded tensor dictionaries."""
if not tensor_dicts:
return {}
return TensorDict()
# Find max sequence length across all dictionaries
lens = []
@ -156,7 +156,7 @@ def concat_padded_tensors(
result[key] = torch.cat(tensors_to_concat, dim=0)
if "attention_mask" not in result:
result["attention_mask"] = attn_mask
return result
return TensorDict(result, batch_size=[len(lens)])
def to_device(data: Dict[str, torch.Tensor | Any], device) -> Dict[str, torch.Tensor]:
@ -290,13 +290,13 @@ class MicroBatchList:
DEFAULT_MAX_TOKENS_PER_MB = int(1e12)
def split_packed_tensor_dict_into_mb_list(
def split_padded_tensor_dict_into_mb_list(
data: TensorDict, mb_spec: MicroBatchSpec, group: Optional[dist.ProcessGroup] = None
) -> MicroBatchList:
"""Split a packed tensordict into micro-batches based on the cumulative sequence lengths.
"""Split a padded tensordict into micro-batches based on the attention mask.
Args:
data (TensorDict): Dictionary containing packed tensors with "cu_seqlens" key.
data (TensorDict): Dictionary containing padded tensors.
mb_spec (MicroBatchSpec): Specification for micro-batch splitting.
group (Optional[dist.ProcessGroup]): Process group for distributed synchronization.
@ -304,24 +304,21 @@ def split_packed_tensor_dict_into_mb_list(
MicroBatchList: A structure containing the split micro-batches and metadata.
"""
assert (
"cu_seqlens" in data
), "Input data must be packed and contain 'cu_seqlens' key."
"attention_mask" in data
), "Input data must be padded and contain 'attention_mask' key."
if mb_spec.max_tokens_per_mb is None:
mb_spec = MicroBatchSpec.new(
mb_spec, max_tokens_per_mb=DEFAULT_MAX_TOKENS_PER_MB
)
cu_seqlens = data["cu_seqlens"]
bs = cu_seqlens.shape[0] - 1
total_lens = int(cu_seqlens[-1])
input_lens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy()
bs = data["attention_mask"].shape[0]
max_seqlen = data["attention_mask"].shape[1]
input_lens = data["attention_mask"].sum(1).long().cpu().numpy()
# Check tensor shapes; split only padded tensors with bs * max_seqlen elements.
to_split = {}
not_to_split = {}
for key, value in data.items():
if key == "cu_seqlens" or key == "max_seqlen":
continue
if not torch.is_tensor(value) or value.numel() != total_lens:
if not torch.is_tensor(value) or value.numel() != bs * max_seqlen:
not_to_split[key] = value
else:
to_split[key] = value
@ -331,6 +328,7 @@ def split_packed_tensor_dict_into_mb_list(
splitted_lens = [
[input_lens[i] for i in group_index] for group_index in group_indices
]
group_n_seqs = [len(x) for x in splitted_lens]
group_lens = [sum(x) for x in splitted_lens]
forward_indices = datapack.flat2d(group_indices)
@ -340,12 +338,16 @@ def split_packed_tensor_dict_into_mb_list(
def _split(tensor):
"""Split and pad a tensor based on forward indices and lens."""
# Unpack the sequence
unpacked = unpack_sequence(tensor, cu_seqlens=cu_seqlens)
unpacked = [tensor[i] for i in range(bs)]
# Reorder according to forward indices
reordered = reorder_list(unpacked, forward_indices)
reordered = torch.cat(reordered)
reordered = torch.stack(reordered)
# Unpack again according to split lens
splitted = unpack_sequence(reordered, lens=group_lens)
splitted = []
offset = 0
for _n_seqs in group_n_seqs:
splitted.append(reordered[offset : offset + _n_seqs])
offset += _n_seqs
return splitted
to_split = dict_map(to_split, lambda x: _split(x))
@ -355,16 +357,7 @@ def split_packed_tensor_dict_into_mb_list(
# organize splitted micro batches
assert len(mbs) == len(splitted_lens), (len(mbs), len(splitted_lens))
for i, (mb, lens) in enumerate(zip(mbs, splitted_lens)):
max_seqlen = max(lens)
lens = torch.tensor(lens, device="cuda")
batch_cu_seqlens = torch.nn.functional.pad(
lens.cumsum(0, dtype=torch.int), (1, 0)
)
results.append(
TensorDict(
**mb, **not_to_split, max_seqlen=max_seqlen, cu_seqlens=batch_cu_seqlens
)
)
results.append(TensorDict(**mb, **not_to_split))
return MicroBatchList(
data=data,
mbs=results,
@ -433,7 +426,7 @@ def pad_mb_list(
# NOTE: GPU page size is 2MB
# Take hidden size 4096 with bf16 dtype as an example,
# the batch size of a page is 256
pad_to_length = (l + 255) // 256 * 256
pad_to_length = (int(l) + 255) // 256 * 256
padded_mb, pad_len = pad_packed_tensor_dict(
mb, pad_to_length, pad_value=pad_value
)
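Given the switch from packed to padded inputs, here is a minimal sketch of the new call flow (import paths assumed from this PR; shapes illustrative):
import torch
from tensordict import TensorDict

from arealite.api.cli_args import MicroBatchSpec  # assumed location
from arealite.utils.data import pack_tensor_dict, split_padded_tensor_dict_into_mb_list

batch = TensorDict(
    input_ids=torch.zeros(4, 8, dtype=torch.long),
    attention_mask=torch.ones(4, 8, dtype=torch.bool),
    batch_size=[4],
)
mb_list = split_padded_tensor_dict_into_mb_list(
    batch, MicroBatchSpec(max_tokens_per_mb=16)
)
# Packing now happens per micro-batch, as in the updated test above.
mb_list.mbs = [pack_tensor_dict(mb) for mb in mb_list.mbs]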

31
arealite/utils/device.py Normal file
View File

@ -0,0 +1,31 @@
from typing import Tuple
import torch
import torch.distributed as dist
from realhf.base import logging
logger = logging.getLogger(__file__)
def _get_current_mem_info(unit: str = "GB", precision: int = 2) -> Tuple[str, str, str, str]:
"""Get current memory usage."""
assert unit in ["GB", "MB", "KB"]
divisor = 1024**3 if unit == "GB" else 1024**2 if unit == "MB" else 1024
mem_allocated = torch.cuda.memory_allocated()
mem_reserved = torch.cuda.memory_reserved()
mem_free, mem_total = torch.cuda.mem_get_info()
mem_used = mem_total - mem_free
mem_allocated = f"{mem_allocated / divisor:.{precision}f}"
mem_reserved = f"{mem_reserved / divisor:.{precision}f}"
mem_used = f"{mem_used / divisor:.{precision}f}"
mem_total = f"{mem_total / divisor:.{precision}f}"
return mem_allocated, mem_reserved, mem_used, mem_total
# Adapted from verl
def log_gpu_stats(head: str, rank: int = 0):
if (not dist.is_initialized()) or (rank is None) or (dist.get_rank() == rank):
mem_allocated, mem_reserved, mem_used, mem_total = _get_current_mem_info()
message = f"{head}, memory allocated (GB): {mem_allocated}, memory reserved (GB): {mem_reserved}, device memory used/total (GB): {mem_used}/{mem_total}"
logger.info(msg=message)
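Typical usage, mirroring the GRPO example below:
log_gpu_stats("recompute logp")  # logs allocated/reserved/used/total memory on rank 0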

View File

@ -1,13 +1,9 @@
from typing import TYPE_CHECKING, Any, Callable
from typing import Callable
from arealite.api.cli_args import EvaluatorConfig
from arealite.api.io_struct import FinetuneSpec
from realhf.base import timeutil
if TYPE_CHECKING:
from tensordict import TensorDict
from torchdata.stateful_dataloader import StatefulDataLoader
class Evaluator:
@ -22,8 +18,7 @@ class Evaluator:
def evaluate(
self,
valid_dataloader: "StatefulDataLoader",
evaluate_fn: Callable[["TensorDict"], Any],
evaluate_fn: Callable,
epoch: int,
step: int,
global_step: int,
@ -32,5 +27,4 @@ class Evaluator:
epochs=int(step == self.ft_sepc.steps_per_epoch - 1), steps=1
):
return
for data in valid_dataloader:
evaluate_fn(data)
evaluate_fn()

View File

@ -1,8 +1,175 @@
from typing import Dict, Optional, Tuple
import numpy as np
import torch
import torch.distributed as dist
@torch.compile
def gather_logprobs(logits: torch.Tensor, labels: torch.Tensor):
log_probs = torch.nn.functional.log_softmax(logits.float(), dim=-1)
def _gather_logprobs(
logits: torch.Tensor, labels: torch.Tensor, temperature: float = 1.0
):
log_probs = torch.nn.functional.log_softmax(logits.float() / temperature, dim=-1)
log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
return log_probs_labels
@torch.compile
def _gather_logprobs_entropy(
logits: torch.Tensor, labels: torch.Tensor, temperature: float = 1.0
):
log_probs = torch.nn.functional.log_softmax(logits.float() / temperature, dim=-1)
entropy = -torch.sum(log_probs.exp() * log_probs, dim=-1)
log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
return log_probs_labels, entropy
def gather_logprobs(
logits: torch.Tensor,
labels: torch.Tensor,
temperature: float = 1.0,
chunk_size: int = 1024,
):
batch_size = logits.shape[0]
if batch_size <= chunk_size:
return _gather_logprobs(logits, labels, temperature)
log_probs_labels_list = []
for i in range(0, batch_size, chunk_size):
end_idx = min(i + chunk_size, batch_size)
chunk_logits = logits[i:end_idx]
chunk_labels = labels[i:end_idx]
chunk_log_probs = _gather_logprobs(chunk_logits, chunk_labels, temperature)
log_probs_labels_list.append(chunk_log_probs)
return torch.cat(log_probs_labels_list)
def gather_logprobs_entropy(
logits: torch.Tensor,
labels: torch.Tensor,
temperature: float = 1.0,
chunk_size: int = 1024,
):
batch_size = logits.shape[0]
if batch_size <= chunk_size:
return _gather_logprobs_entropy(logits, labels, temperature)
log_probs_labels_list = []
entropy_list = []
for i in range(0, batch_size, chunk_size):
end_idx = min(i + chunk_size, batch_size)
chunk_logits = logits[i:end_idx]
chunk_labels = labels[i:end_idx]
chunk_log_probs, chunk_entropy = _gather_logprobs_entropy(
chunk_logits, chunk_labels, temperature
)
log_probs_labels_list.append(chunk_log_probs)
entropy_list.append(chunk_entropy)
return torch.cat(log_probs_labels_list), torch.cat(entropy_list)
@torch.no_grad()
def masked_normalization(
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
dim=None,
unbiased=False,
eps=1e-5,
high_precision=True,
all_reduce=True,
reduce_group=None,
):
dtype = torch.float64 if high_precision else torch.float32
x = x.to(dtype)
if dim is None:
dim = tuple(range(len(x.shape)))
if mask is None:
factor = torch.tensor(
np.prod([x.shape[d] for d in dim]), dtype=dtype, device=x.device
)
else:
mask = mask.to(dtype)
x = x * mask
factor = mask.sum(dim, keepdim=True)
x_sum = x.sum(dim=dim, keepdim=True)
x_sum_sq = x.square().sum(dim=dim, keepdim=True)
if dist.is_initialized() and all_reduce:
dist.all_reduce(factor, op=dist.ReduceOp.SUM, group=reduce_group)
dist.all_reduce(x_sum, op=dist.ReduceOp.SUM, group=reduce_group)
dist.all_reduce(
x_sum_sq,
op=dist.ReduceOp.SUM,
group=reduce_group,
)
mean = x_sum / factor
meansq = x_sum_sq / factor
var = meansq - mean**2
if unbiased:
var *= factor / (factor - 1)
return ((x - mean) / (var.sqrt() + eps)).float()
def ppo_actor_loss_fn(
logprobs: torch.Tensor,
old_logprobs: torch.Tensor,
advantages: torch.Tensor,
eps_clip: float,
loss_mask: torch.Tensor,
c_clip: Optional[float] = None,
proximal_logprobs: Optional[torch.Tensor] = None,
behav_imp_weight_cap: Optional[float] = None,
) -> Tuple[torch.Tensor, Dict]:
denorm_logprobs = (
proximal_logprobs if proximal_logprobs is not None else old_logprobs
)
loss_mask_count = loss_mask.count_nonzero() or 1
ratio = torch.where(loss_mask, torch.exp(logprobs - denorm_logprobs), 0)
clipped_ratio = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip)
pg_loss1 = -advantages * ratio
pg_loss2 = -advantages * clipped_ratio
clip_mask = pg_loss1.detach() < pg_loss2.detach()
pg_loss = torch.max(pg_loss1, pg_loss2)
if c_clip is not None:
assert c_clip > 1.0, c_clip
pg_loss3 = torch.sign(advantages) * c_clip * advantages
dual_clip_mask = pg_loss3.detach() < pg_loss.detach()
pg_loss = torch.min(pg_loss, pg_loss3)
else:
dual_clip_mask = torch.zeros_like(clip_mask)
if proximal_logprobs is not None:
behav_kl = proximal_logprobs - old_logprobs
behav_imp_weight = behav_kl.exp()
behav_mask = (
(behav_imp_weight <= behav_imp_weight_cap).logical_and(loss_mask)
if behav_imp_weight_cap is not None
else loss_mask
)
behav_kl = torch.where(behav_mask, behav_kl, 0.0)
behav_imp_weight = torch.where(behav_mask, behav_imp_weight, 0.0)
pg_loss = pg_loss * behav_imp_weight
logging_loss = pg_loss.detach()
pg_loss = torch.where(loss_mask, pg_loss, 0).sum() / loss_mask_count
clip_mask.logical_and_(loss_mask)
dual_clip_mask.logical_and_(loss_mask)
stat = dict(
loss=logging_loss,
importance_weight=ratio.detach(),
approx_kl=(logprobs - denorm_logprobs).detach(),
clip_mask=clip_mask,
dual_clip_mask=dual_clip_mask,
)
if proximal_logprobs is not None:
stat["behave_imp_weight"] = behav_imp_weight
stat["behave_approx_kl"] = behav_kl
stat["behave_mask"] = behav_mask
return pg_loss, stat
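A minimal sketch of calling the decoupled loss above with random tensors (flat packed shape assumed; eps_clip and behav_imp_weight_cap mirror the GSM8k config below):
import torch

logprobs = torch.randn(64, requires_grad=True)   # current policy logprobs
prox_logprobs = (logprobs - 0.01).detach()       # recomputed proximal policy
old_logprobs = (logprobs - 0.05).detach()        # behavior policy at rollout time
advantages = torch.randn(64)
loss_mask = torch.ones(64, dtype=torch.bool)

loss, stat = ppo_actor_loss_fn(
    logprobs,
    old_logprobs,
    advantages,
    eps_clip=0.4,
    loss_mask=loss_mask,
    proximal_logprobs=prox_logprobs,
    behav_imp_weight_cap=5.0,
)
loss.backward()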

56
arealite/utils/http.py Normal file
View File

@ -0,0 +1,56 @@
import asyncio
from typing import Any, Dict, Optional
import aiohttp
DEFAULT_RETRIES = 1
DEFAULT_REQUEST_TIMEOUT = 3600
async def arequest_with_retry(
addr: str,
endpoint: str,
payload: Optional[Dict[str, Any]] = None,
method: str = "POST",
max_retries: Optional[int] = None,
timeout: Optional[float] = None,
retry_delay: float = 1.0,
) -> Dict[str, Any]:
timeout = timeout or DEFAULT_REQUEST_TIMEOUT
last_exception = None
max_retries = max_retries or DEFAULT_RETRIES
base_url = f"http://{addr}"
url = f"{base_url}{endpoint}"
for attempt in range(max_retries):
try:
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(
total=timeout,
sock_connect=timeout,
)
) as session:
if method.upper() == "GET":
response = await session.get(url)
elif method.upper() == "POST":
response = await session.post(url, json=payload)
elif method.upper() == "PUT":
response = await session.put(url, json=payload)
elif method.upper() == "DELETE":
response = await session.delete(url)
else:
raise ValueError(f"Unsupported HTTP method: {method}")
response.raise_for_status()
return await response.json()
except (
aiohttp.ClientError,
aiohttp.ClientResponseError,
asyncio.TimeoutError,
) as e:
last_exception = e
if attempt < max_retries - 1:
await asyncio.sleep(retry_delay)
continue
raise RuntimeError(
f"Failed after {max_retries} retries each. " f"Last error: {last_exception}"
)
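Example call against one of the launched servers (address and endpoint are illustrative; the function returns the parsed JSON body):
import asyncio

async def fetch_server_info(addr: str):
    # GET with bounded retries; raises RuntimeError after exhausting attempts.
    return await arequest_with_retry(
        addr, "/get_server_info", method="GET", max_retries=3, timeout=30
    )

# info = asyncio.run(fetch_server_info("localhost:10001"))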

8
arealite/utils/model.py Normal file
View File

@ -0,0 +1,8 @@
import torch
# Copied from trl
def disable_dropout_in_model(model: torch.nn.Module) -> None:
for module in model.modules():
if isinstance(module, torch.nn.Dropout):
module.p = 0

100
arealite/utils/network.py Normal file
View File

@ -0,0 +1,100 @@
import random
import socket
from typing import List, Set
def gethostname():
return socket.gethostname()
def gethostip():
return socket.gethostbyname(socket.gethostname())
def find_free_ports(
count: int, port_range: tuple = (1024, 65535), exclude_ports: Set[int] | None = None
) -> List[int]:
"""
Find multiple free ports within a specified range.
Args:
count: Number of free ports to find
port_range: Tuple of (min_port, max_port) to search within
exclude_ports: Set of ports to exclude from search
Returns:
List of free port numbers
Raises:
ValueError: If unable to find requested number of free ports
"""
if exclude_ports is None:
exclude_ports = set()
min_port, max_port = port_range
free_ports = []
attempted_ports = set()
# Calculate available port range
available_range = max_port - min_port + 1 - len(exclude_ports)
if count > available_range:
raise ValueError(
f"Cannot find {count} ports in range {port_range}. "
f"Only {available_range} ports available."
)
max_attempts = count * 10 # Reasonable limit to avoid infinite loops
attempts = 0
while len(free_ports) < count and attempts < max_attempts:
# Generate random port within range
port = random.randint(min_port, max_port)
# Skip if port already attempted or excluded
if port in attempted_ports or port in exclude_ports:
attempts += 1
continue
attempted_ports.add(port)
if is_port_free(port):
free_ports.append(port)
attempts += 1
if len(free_ports) < count:
raise ValueError(
f"Could only find {len(free_ports)} free ports "
f"out of {count} requested after {max_attempts} attempts"
)
return sorted(free_ports)
def is_port_free(port: int) -> bool:
"""
Check if a port is free by attempting to bind to it.
Args:
port: Port number to check
Returns:
True if port is free, False otherwise
"""
# Check TCP
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
sock.bind(("", port))
sock.close()
except OSError:
return False
# Check UDP
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
sock.bind(("", port))
sock.close()
return True
except OSError:
return False
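Usage mirrors the local launcher above:
# Reserve one port for the server and one for its dist init address.
server_port, dist_port = find_free_ports(2, port_range=(10000, 50000))
addr = f"{gethostip()}:{server_port}"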

View File

@ -21,6 +21,18 @@ class Saver:
freq_sec=config.freq_secs,
)
@staticmethod
def get_save_checkpoint_root(
config: SaverConfig,
name: str = "default",
):
path = os.path.join(
f"{config.fileroot}/checkpoints/{getpass.getuser()}/{config.experiment_name}/{config.trial_name}",
name,
)
os.makedirs(path, exist_ok=True)
return path
@staticmethod
def get_save_checkpoint_path(
config: SaverConfig,
@ -30,8 +42,7 @@ class Saver:
name: str = "default",
):
path = os.path.join(
f"{config.fileroot}/checkpoints/{getpass.getuser()}/{config.experiment_name}/{config.trial_name}",
name,
Saver.get_save_checkpoint_root(config, name),
f"epoch{epoch}epochstep{step}globalstep{globalstep}",
)
os.makedirs(path, exist_ok=True)
@ -51,7 +62,9 @@ class Saver:
epochs=int(step == self.ft_sepc.steps_per_epoch - 1), steps=1
):
return
path = self.get_save_checkpoint_path(epoch, step, global_step, name)
path = Saver.get_save_checkpoint_path(
self.config, epoch, step, global_step, name
)
weight_format = "hf"
with_optim = False
if self.for_recover:
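With these helpers, checkpoints land in a layout along these lines (a sketch; user name and step values are illustrative):
root = Saver.get_save_checkpoint_root(config.saver)
# -> {fileroot}/checkpoints/{user}/gsm8k-grpo/trial0/default
path = Saver.get_save_checkpoint_path(config.saver, 0, 9, 9)
# -> .../default/epoch0epochstep9globalstep9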

View File

@ -1,7 +1,7 @@
import getpass
import os
import time
from typing import Dict
from typing import Dict, List
import torch.distributed as dist
import wandb
@ -21,6 +21,8 @@ class StatsLogger:
self.ft_spec = ft_spec
self.init()
self._last_commit_step = 0
def init(self):
if dist.is_initialized() and dist.get_rank() != 0:
return
@ -61,7 +63,7 @@ class StatsLogger:
if self.summary_writer is not None:
self.summary_writer.close()
def commit(self, epoch: int, step: int, global_step: int, data: Dict):
def commit(self, epoch: int, step: int, global_step: int, data: Dict | List[Dict]):
if dist.is_initialized() and dist.get_rank() != 0:
return
self.info(
@ -69,12 +71,17 @@ class StatsLogger:
f"Step {step+1}/{self.ft_spec.steps_per_epoch} "
f"Train step {global_step + 1}/{self.ft_spec.total_train_steps} done."
)
self.info("Stats:")
self.print_stats(data)
wandb.log(data, step=global_step)
if self.summary_writer is not None:
for key, val in data.items():
self.summary_writer.add_scalar(f"{key}", val, global_step)
if isinstance(data, Dict):
data = [data]
log_step = max(global_step, self._last_commit_step)
for i, item in enumerate(data):
self.info(f"Stats ({i+1}/{len(data)}):")
self.print_stats(item)
wandb.log(item, step=log_step + i)
if self.summary_writer is not None:
for key, val in item.items():
self.summary_writer.add_scalar(f"{key}", val, log_step + i)
self._last_commit_step = log_step + len(data) - 1
def print_stats(self, stats: Dict[str, float]):
self.info("\n" + tabulate_stats(stats))

View File

@ -17,19 +17,24 @@ class RLVRWorkflow(RolloutWorkflow):
reward_fn,
gconfig: GenerationHyperparameters,
tokenizer: PreTrainedTokenizerFast,
enable_thinking: bool,
):
self.reward_fn = reward_fn
self.gconfig = gconfig
self.tokenizer = tokenizer
self.enable_thinking = enable_thinking
async def arun_episode(self, engine, data):
text = self.tokenizer.apply_chat_template(
data["messages"], tokenize=False, add_generation_prompt=True
input_ids = self.tokenizer.apply_chat_template(
data["messages"],
tokenize=True,
add_generation_prompt=True,
enable_thinking=self.enable_thinking,
)
n_samples = self.gconfig.n_samples
req = LLMRequest(
rid=uuid.uuid4().hex,
text=text,
input_ids=input_ids,
gconfig=self.gconfig.new(n_samples=1),
)
resps = await asyncio.gather(*[engine.agenerate(req) for _ in range(n_samples)])
@ -37,13 +42,13 @@ class RLVRWorkflow(RolloutWorkflow):
results = []
for resp in resps:
seq = resp.input_tokens + resp.output_tokens
logprobs = [0] * resp.input_len + resp.output_logprobs
prompt_mask = [1] * resp.input_len + [0] * resp.output_len
logprobs = [0.0] * resp.input_len + resp.output_logprobs
loss_mask = [0] * resp.input_len + [1] * resp.output_len
versions = [-1] * resp.input_len + resp.output_versions
reward = self.reward_fn(
prompt=req.text,
completions=resp.completions,
prompt=self.tokenizer.decode(input_ids),
completions=self.tokenizer.decode(resp.output_tokens),
prompt_ids=resp.input_tokens,
completion_ids=resp.output_tokens,
**data,
@ -51,10 +56,10 @@ class RLVRWorkflow(RolloutWorkflow):
res = dict(
# unsqueeze to add an additional batch dimension
input_ids=torch.tensor(seq).unsqueeze(0),
prompt_mask=torch.tensor(prompt_mask).unsqueeze(0),
loss_mask=torch.tensor(loss_mask).unsqueeze(0),
logprobs=torch.tensor(logprobs).unsqueeze(0),
versions=torch.tensor(versions).unsqueeze(0),
attention_mask=torch.ones(len(seq)).unsqueeze(0),
attention_mask=torch.ones(len(seq), dtype=torch.bool).unsqueeze(0),
# reward
rewards=torch.tensor([reward]),
)
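For a prompt of P tokens and a sampled response of R tokens, each returned sample is laid out as follows (a summary of the code above):
# input_ids:      [p_1 .. p_P, o_1 .. o_R]
# loss_mask:      [0] * P + [1] * R      # train only on generated tokens
# logprobs:       [0.0] * P + sampler logprobs for the R output tokens
# versions:       [-1] * P + the model version that produced each output token
# attention_mask: [1] * (P + R)          # all True; sequences are unpadded here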

View File

@ -0,0 +1,129 @@
experiment_name: gsm8k-grpo
trial_name: trial0
allocation_mode: sglang.d4p1t1+d4p1t1
n_nodes: 1
n_gpus_per_node: 8
cluster:
fileroot: /tmp/arealite/experiments
name_resolve:
type: nfs
nfs_record_root: /tmp/areal/name_resolve
seed: 1
total_train_epochs: 10
tokenizer_path: ${actor.path}
async_training: true
rollout:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
max_concurrent_rollouts: 256
queue_size: null
consumer_batch_size: ${train_dataset.batch_size}
max_head_offpolicyness: 4
enable_rollout_tracing: false
gconfig:
n_samples: 4
min_new_tokens: 0
max_new_tokens: 512
greedy: false
temperature: 1.0
actor:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
path: /storage/openpsi/models/Qwen__Qwen2-1.5B-Instruct/
init_from_scratch: false
disable_dropout: true
gradient_checkpointing: false
dtype: bfloat16
mb_spec:
max_tokens_per_mb: 10240
optimizer:
type: adam
lr: 2e-6
weight_decay: 0.01
beta1: 0.9
beta2: 0.999
eps: 1e-8
lr_scheduler_type: constant
gradient_clipping: 1.0
warmup_steps_proportion: 0.001
backend: fsdp
group_size: ${gconfig.n_samples}
group_adv_norm: false
eps_clip: 0.4
temperature: ${gconfig.temperature}
reward_scaling: 10.0
reward_bias: -0.5
kl_ctl: 0.0
ppo_n_minibatches: 1
recompute_logprob: true
use_decoupled_loss: true
behav_imp_weight_cap: 5.0
ref:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
path: ${actor.path}
init_from_scratch: false
dtype: ${actor.dtype}
mb_spec:
max_tokens_per_mb: 10240
optimizer: null
backend: fsdp
# SGLang
server_only: false
sglang:
model_path: ${actor.path}
random_seed: ${seed}
skip_tokenizer_init: true
dtype: ${actor.dtype}
max_running_requests: null
context_length: 32768
mem_fraction_static: 0.9
# datasets
train_dataset:
batch_size: 256
shuffle: true
pin_memory: true
valid_dataset:
batch_size: 256
shuffle: true
pin_memory: true
# Utilities
saver:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
fileroot: ${cluster.fileroot}
freq_epochs: 1
freq_steps: null
freq_secs: null
checkpointer:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
fileroot: ${cluster.fileroot}
freq_epochs: 1
freq_steps: null
freq_secs: 3600
evaluator:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
fileroot: ${cluster.fileroot}
freq_epochs: 1
freq_steps: null
freq_secs: null
stats_logger:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
fileroot: ${cluster.fileroot}
wandb:
mode: disabled
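Under the local launcher above, this config would be started roughly as `python -m arealite.launcher.local <training_script.py> --config <this yaml>` (entrypoint path illustrative). The allocation_mode string sglang.d4p1t1+d4p1t1 dedicates 4 of the 8 GPUs to data-parallel SGLang servers and the other 4 to FSDP training ranks.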

View File

@ -16,7 +16,7 @@ model:
path: /storage/openpsi/models/Qwen__Qwen3-1.7B/
init_from_scratch: false
gradient_checkpointing: false
bf16: true
dtype: bfloat16
mb_spec:
max_tokens_per_mb: 4096
optimizer:
@ -34,13 +34,11 @@ train_dataset:
batch_size: 128
shuffle: true
pin_memory: true
num_workers: 4
valid_dataset:
batch_size: 128
shuffle: true
pin_memory: true
num_workers: 4
# Utilities
saver:

View File

@ -0,0 +1,259 @@
import os
import re
import sys
import torch
import torch.distributed as dist
from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node
from torchdata.stateful_dataloader import StatefulDataLoader
from arealite.api.cli_args import GRPOConfig, load_expr_config
from arealite.api.io_struct import FinetuneSpec, WeightUpdateMeta
from arealite.engine.ppo.actor import FSDPPPOActor
from arealite.engine.sglang_remote import RemoteSGLangEngine
from arealite.utils.device import log_gpu_stats
from arealite.utils.evaluator import Evaluator
from arealite.utils.saver import Saver
from arealite.utils.stats_logger import StatsLogger
from arealite.workflow.rlvr import RLVRWorkflow
from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import stats_tracker
def process_gsm8k_rl_dataset(dataset: Dataset):
def process(sample):
messages = [{"role": "user", "content": sample["question"]}]
return {"messages": messages}
dataset = dataset.map(process).remove_columns(["question"])
return dataset
def get_gsm8k_dataset(split, rank, world_size):
dataset = load_dataset(path="openai/gsm8k", name="main", split=split)
dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
return process_gsm8k_rl_dataset(dataset)
# Adapted from verl.
def extract_solution(solution_str, method="strict") -> str | None:
assert method in ["strict", "flexible"]
final_answer = None
if method == "strict":
# this also tests the formatting of the model
solutions = re.findall("#### (\\-?[0-9\\.\\,]+)", solution_str)
if len(solutions) == 0:
final_answer = None
else:
# take the last solution
final_answer = solutions[-1].replace(",", "").replace("$", "")
elif method == "flexible":
answer = re.findall("(\\-?[0-9\\.\\,]+)", solution_str)
final_answer = None
if len(answer) == 0:
# no reward if there is no answer
pass
else:
invalid_str = ["", "."]
# find the last number that is not '.'
for final_answer in reversed(answer):
if final_answer not in invalid_str:
break
return final_answer
def gsm8k_reward_fn(prompt, completions, prompt_ids, completion_ids, answer, **kwargs):
from realhf.impl.dataset.math_parser import extract_answer
sol = extract_answer(completions, data_name="math")
ans = extract_solution(solution_str=answer, method="strict")
if sol is None:
return 0
if ans is None:
return 0
return int(sol.strip() == ans.strip())
def main_grpo():
config, _ = load_expr_config(sys.argv[1:], GRPOConfig)
config: GRPOConfig
rank = int(os.getenv("RANK"))
world_size = int(os.getenv("WORLD_SIZE"))
tokenizer = load_hf_tokenizer(config.tokenizer_path)
# Create dataset and dataloaders
train_dataloader = StatefulDataLoader(
get_gsm8k_dataset("train", rank, world_size),
batch_size=config.train_dataset.batch_size // world_size,
shuffle=config.train_dataset.shuffle,
num_workers=config.train_dataset.num_workers,
collate_fn=lambda x: x,
drop_last=config.train_dataset.drop_last,
)
valid_dataloader = StatefulDataLoader(
get_gsm8k_dataset("test", rank, world_size),
batch_size=config.valid_dataset.batch_size // world_size,
shuffle=config.valid_dataset.shuffle,
num_workers=config.valid_dataset.num_workers,
collate_fn=lambda x: x,
drop_last=config.valid_dataset.drop_last,
)
ft_spec = FinetuneSpec(
total_train_epochs=config.total_train_epochs,
dataset_size=len(train_dataloader) * config.train_dataset.batch_size,
train_batch_size=config.train_dataset.batch_size,
)
# Initialize inference engine
rollout = RemoteSGLangEngine(config.rollout)
rollout.initialize(None, ft_spec)
eval_rollout = RemoteSGLangEngine(config.rollout)
eval_rollout.initialize(None, ft_spec)
# NOTE: set a large version such that eval does not have any offpolicyness control
eval_rollout.set_version(int(1e12))
# Initialize train engine
actor = FSDPPPOActor(config=config.actor)
actor.initialize(None, ft_spec)
ref = None
if config.actor.kl_ctl > 0 and config.ref is not None:
ref = FSDPPPOActor(config=config.ref)
ref.initialize(None, ft_spec)
# Create rollout workflow
if tokenizer.pad_token_id not in config.gconfig.stop_token_ids:
config.gconfig.stop_token_ids.append(tokenizer.pad_token_id)
if tokenizer.eos_token_id not in config.gconfig.stop_token_ids:
config.gconfig.stop_token_ids.append(tokenizer.eos_token_id)
workflow = RLVRWorkflow(
reward_fn=gsm8k_reward_fn,
gconfig=config.gconfig,
tokenizer=tokenizer,
enable_thinking=False,
)
# Run training.
saver = Saver(config.saver, ft_spec, for_recover=False)
logger = StatsLogger(config.stats_logger, ft_spec)
evaluator = Evaluator(config.evaluator, ft_spec)
total_epochs = config.total_train_epochs
steps_per_epoch = len(train_dataloader)
max_steps = total_epochs * steps_per_epoch
logger.info(f"total_epochs={total_epochs} step_per_epoch={steps_per_epoch}")
data_generator = iter(train_dataloader)
for global_step in range(max_steps):
epoch = global_step // steps_per_epoch
step = global_step % steps_per_epoch
with stats_tracker.record_timing("rollout"):
if config.async_training:
batch = rollout.prepare_batch(
data_generator,
train_dataloader,
workflow=workflow,
)
else:
try:
data = next(data_generator)
except StopIteration:
data_generator = iter(train_dataloader)
data = next(data_generator)
batch = rollout.rollout_batch(data, workflow=workflow)
batch = batch.to(actor.device)
# Barrier to synchronize all trainer processes after rollout.
dist.barrier()
torch.cuda.synchronize()
if config.actor.recompute_logprob or config.actor.use_decoupled_loss:
with stats_tracker.record_timing("recompute_logp"):
logp = actor.compute_logp(batch)
batch["prox_logp"] = logp
log_gpu_stats("recompute logp")
if ref is not None:
with stats_tracker.record_timing("ref_logp"):
batch["ref_logp"] = ref.compute_logp(batch)
log_gpu_stats("ref logp")
with stats_tracker.record_timing("compute_advantage"):
actor.compute_advantages(batch)
log_gpu_stats("compute advantages")
with (
stats_tracker.record_timing("train_step"),
stats_tracker.scope("grpo_actor"),
):
stats = actor.ppo_update(batch)
actor.step_lr_scheduler()
log_gpu_stats("ppo update")
with stats_tracker.record_timing("update_weights"):
meta = WeightUpdateMeta(
type="disk",
path=os.path.join(
Saver.get_save_checkpoint_root(config.saver),
"update_weights",
str(global_step),
),
alloc_mode=None,
comm_backend=None,
model_version=global_step + 1,
)
if dist.get_rank() == 0:
future = rollout.update_weights(meta)
actor.upload_weights(meta)
if dist.get_rank() == 0:
future.result()
rollout.set_version(global_step + 1)
dist.barrier()
with stats_tracker.record_timing("save"):
saver.save(actor, epoch, step, global_step)
with stats_tracker.record_timing("eval"):
def evaluate_fn():
rollout.pause()
cnt = 0
for data in valid_dataloader:
for item in data:
eval_rollout.submit(item, workflow)
cnt += 1
batch = eval_rollout.wait(cnt, timeout=None)
rewards = batch["rewards"].float().to(actor.device)
with stats_tracker.scope("grpo-eval"):
stats_tracker.denominator(
n_seqs=torch.ones(
rewards.shape[0],
device=rewards.device,
dtype=torch.bool,
)
)
stats_tracker.stat(task_reward=rewards, denominator="n_seqs")
rollout.resume()
evaluator.evaluate(
evaluate_fn,
epoch,
step,
global_step,
)
logger.commit(epoch, step, global_step, stats)
logger.close()
eval_rollout.destroy()
rollout.destroy()
if ref is not None:
ref.destroy()
actor.destroy()
if __name__ == "__main__":
main_grpo()
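A worked example of the reward path above (the behavior of extract_answer is assumed from realhf's math parser):
# dataset answer: "...reasoning...\n#### 72"        -> extract_solution -> "72"
# completion:     "... the answer is \\boxed{72}."  -> extract_answer   -> "72" (assumed)
# gsm8k_reward_fn(...) == int("72" == "72") == 1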

View File

@ -22,10 +22,8 @@ def process_gsm8k_sft_dataset(dataset: Dataset, tokenizer):
sample["question"] + sample["answer"] + tokenizer.eos_token
)
prompt_token = tokenizer.encode(sample["question"])
prompt_mask = [1] * len(prompt_token) + [0] * (
len(seq_token) - len(prompt_token)
)
return {"input_ids": seq_token, "prompt_mask": prompt_mask}
loss_mask = [0] * len(prompt_token) + [1] * (len(seq_token) - len(prompt_token))
return {"input_ids": seq_token, "loss_mask": loss_mask}
dataset = dataset.map(process).remove_columns(["question", "answer"])
return dataset
@ -95,22 +93,31 @@ def main_sft():
with stats_tracker.record_timing("save"):
saver.save(engine, epoch, step, global_step)
with stats_tracker.record_timing("eval"), stats_tracker.scope("sft-eval"):
with stats_tracker.record_timing("eval"):
# No need to log anything. Logging will be handled outside
# via stats_tracker.export().
def evaluate_fn():
with stats_tracker.scope("sft-eval"):
for data in valid_dataloader:
engine.evaluate_lm(data)
evaluator.evaluate(
valid_dataloader,
engine.evaluate_lm,
evaluate_fn,
epoch,
step,
global_step,
)
logger.commit(epoch, step, global_step, stats_tracker.export())
logger.commit(
epoch,
step,
global_step,
stats_tracker.export(reduce_group=engine.parallelism_group),
)
global_step += 1
engine.destroy()
logger.close()
engine.destroy()
if __name__ == "__main__":

View File

@ -41,12 +41,12 @@ class JobException(Exception):
class JobInfo:
name: str
state: JobState
host: str = (
host: Optional[str] = (
None # The host on which the job is/was running. None if the job had not run.
)
submit_time: str = None
start_time: str = None
slurm_id: str = None # Slurm only. The Slurm id of the job.
submit_time: Optional[str] = None
start_time: Optional[str] = None
slurm_id: Optional[int] = None # Slurm only. The Slurm id of the job.
class SchedulerClient:

View File

@ -40,7 +40,7 @@ def execute_shell_command(command: str) -> subprocess.Popen:
)
def apply_sglang_path():
def apply_sglang_patch():
p = Path(os.path.dirname(__file__))
patch_path = str(
p.parent.parent
@ -75,7 +75,7 @@ def launch_server_cmd(command: str, port: int = 30000):
If no port is specified, a free port is reserved.
"""
if not ray.is_initialized():
apply_sglang_path()
apply_sglang_patch()
assert port is not None
full_command = f"{command} --port {port}"
process = execute_shell_command(full_command)