bowei.fw 2025-07-11 11:10:55 +08:00
parent 1bb23f2399
commit 0cbddb8aba
16 changed files with 1198 additions and 129 deletions

View File

@ -127,7 +127,7 @@ class TrainEngineConfig:
gradient_checkpointing: bool = field(
default=True, metadata={"help": "Enable gradient checkpointing"}
)
bf16: bool = field(default=False, metadata={"help": "Use bf16 precision"})
dtype: str = field(default="float16", metadata={"help": "Parameter dtype."})
optimizer: Optional[OptimizerConfig] = field(
default=None, metadata={"help": "Optimizer configuration"}
)
@ -220,6 +220,10 @@ class SGLangConfig:
https://github.com/sgl-project/sglang for detailed documentation.
"""
model_path: str = ""
random_seed: int = 1
skip_tokenizer_init: bool = False
disable_cuda_graph: bool = False
disable_radix_cache: bool = False
disable_cuda_graph_padding: bool = False
@ -274,35 +278,27 @@ class SGLangConfig:
@staticmethod
def build_cmd(
sglang_config: "SGLangConfig",
model_path,
tp_size,
base_gpu_id,
host,
port,
dist_init_addr: Optional[str] = None,
served_model_name: Optional[str] = None,
skip_tokenizer_init: bool = True,
):
from realhf.base import network, pkg_version, seeding
from realhf.experiments.common.utils import asdict as conf_as_dict
args: Dict = conf_as_dict(sglang_config)
args["random_seed"] = seeding.get_seed()
if served_model_name is None:
served_model_name = model_path
host_ip = network.gethostip()
host = "localhost" if not sglang_config.enable_metrics else host_ip
args = dict(
host=host,
model_path=model_path,
port=port,
# Model and tokenizer
tokenizer_path=model_path,
tokenizer_path=sglang_config.model_path,
tokenizer_mode="auto",
load_format="auto",
trust_remote_code=True,
device="cuda",
served_model_name=served_model_name,
is_embedding=False,
skip_tokenizer_init=skip_tokenizer_init,
# Other runtime options
tp_size=tp_size,
# Because we have set CUDA_VISIBLE_DEVICES to a single GPU in each process
@ -365,17 +361,13 @@ class InferenceEngineConfig:
"the request will not be accepted.",
},
)
# Used by remote inference engines.
server_addrs: List[str] = field(
default_factory=list,
metadata={"help": "List of server addresses for inference."},
)
schedule_policy: str = field(
default="round_robin",
metadata={"help": "Request scheduling policy", "choices": ["round_robin"]},
)
setup_timeout: float = field(default=90.0)
request_timeout: float = field(
default=30.0, metadata={"help": "Timeout for HTTP requests."}
default=3600, metadata={"help": "Timeout for HTTP requests."}
)
request_retries: int = field(
default=3, metadata={"help": "Number of retries for failed requests."}
@ -616,6 +608,9 @@ class BaseExperimentConfig:
evaluator: EvaluatorConfig = field(default_factory=EvaluatorConfig)
stats_logger: StatsLoggerConfig = field(default_factory=StatsLoggerConfig)
server_only: bool = False
sglang: SGLangConfig = field(default_factory=SGLangConfig)
@dataclass
class SFTConfig(BaseExperimentConfig):
@ -633,7 +628,7 @@ class GRPOConfig(BaseExperimentConfig):
ref: PPOActorConfig = field(default_factory=PPOActorConfig)
def load_expr_config(argv: List[str], config_cls) -> Tuple[BaseExperimentConfig, str]:
def parse_cli_args(argv: List[str]):
parser = argparse.ArgumentParser()
parser.add_argument(
"--config", help="The path of the main configuration file", required=True
@ -644,19 +639,26 @@ def load_expr_config(argv: List[str], config_cls) -> Tuple[BaseExperimentConfig,
config_file = Path(args.config).absolute()
assert config_file.exists()
# hydra only recognize relative paths
relpath = Path(
os.path.relpath(str(config_file), (Path(__file__).parent).absolute())
)
relpath = Path(os.path.relpath(str(config_file), Path(__file__).parent.absolute()))
hydra_init(config_path=str(relpath.parent), job_name="app", version_base=None)
cfg = hydra_compose(
config_name=str(relpath.name).removesuffix(".yaml"),
overrides=overrides,
)
return cfg, config_file
def to_structured_cfg(cfg, config_cls):
# Merge with the default configuration.
# The yaml and commandline can omit some default values defined in python dataclasses.
default_cfg = OmegaConf.structured(config_cls)
cfg = OmegaConf.merge(default_cfg, cfg)
return cfg
def load_expr_config(argv: List[str], config_cls) -> Tuple[BaseExperimentConfig, str]:
cfg, config_file = parse_cli_args(argv)
cfg = to_structured_cfg(cfg, config_cls=config_cls)
cfg = OmegaConf.to_object(cfg)
assert isinstance(cfg, BaseExperimentConfig)
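
The refactor splits CLI parsing (`parse_cli_args`) from dataclass merging (`to_structured_cfg`), so launchers can inspect the raw YAML before committing to a config class. A minimal sketch of how the pieces compose, assuming a hypothetical config path and a Hydra-style `key=value` override:

import sys

from omegaconf import OmegaConf

from arealite.api.cli_args import GRPOConfig, parse_cli_args, to_structured_cfg

# e.g. argv = ["--config", "path/to/gsm8k_grpo.yaml", "cluster.fileroot=/tmp/arealite"]
cfg, config_file = parse_cli_args(sys.argv[1:])       # raw OmegaConf: YAML + overrides
cfg = to_structured_cfg(cfg, config_cls=GRPOConfig)   # fill defaults from the dataclass
cfg = OmegaConf.to_object(cfg)                        # concrete GRPOConfig instance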

View File

@ -127,10 +127,23 @@ class InferenceEngine(abc.ABC):
"""Asynchronously submit a request to the inference engine. Exits immediately."""
raise NotImplementedError()
def wait(self, count: int, timeout: float) -> TensorDict:
def wait(
self,
count: int,
timeout: float | None = None,
should_accept: Callable | None = None,
) -> TensorDict:
"""Wait for a specified number of requests to complete, with a timeout."""
raise NotImplementedError()
def pause(self):
"""Pause request submission for async rollout. Used during evaluation to prevent data over generation."""
raise NotImplementedError()
def resume(self):
"""Resume request submission for async rollout."""
raise NotImplementedError()
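
A hedged sketch of how a training loop is expected to drive the new hooks around evaluation (`engine` construction and the evaluation callback are elided; `count` and `timeout` values are illustrative):

# Sketch only: quiesce rollout generation around an evaluation pass.
engine.pause()                                    # stop admitting new rollout tasks
try:
    batch = engine.wait(count=64, timeout=600.0)  # drain already-submitted requests
    evaluate(batch)                               # hypothetical evaluation hook
finally:
    engine.resume()                               # training rollouts continue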
def rollout(
self, data: List[Dict[str, Any]], workflow: "RolloutWorkflow"
) -> TensorDict:

View File

@ -32,7 +32,7 @@ from arealite.utils.data import (
pad_and_stack_tensors_along_first_dim,
pad_mb_list,
reorder_list,
split_packed_tensor_dict_into_mb_list,
split_padded_tensor_dict_into_mb_list,
unpack_sequence,
unsqueeze_mb_list,
)
@ -95,7 +95,7 @@ class FSDPEngine(TrainEngine):
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
self.device = torch.device(int(os.environ["LOCAL_RANK"]))
dtype = torch.bfloat16 if self.config.bf16 else torch.float16
dtype = getattr(torch, self.config.dtype)
self.model_config = AutoConfig.from_pretrained(
pretrained_model_name_or_path=self.config.path,
trust_remote_code=True,
@ -323,11 +323,8 @@ class FSDPEngine(TrainEngine):
if isinstance(input_, dict):
input_ = TensorDict(input_, batch_size=[input_["input_ids"].shape[0]])
input_ = amend_position_ids(input_)
packed_input = pack_tensor_dict(input_)
mb_list = split_packed_tensor_dict_into_mb_list(
packed_input,
self.config.mb_spec,
)
mb_list = split_padded_tensor_dict_into_mb_list(input_, self.config.mb_spec)
mb_list.mbs = [pack_tensor_dict(mb) for mb in mb_list.mbs]
mb_list = pad_mb_list(mb_list, pad_value=0.0)
# NOTE: We unsqueeze here because huggingface transformer models requires
# packed input to be of shape [1, total_seqlen].
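
The order of operations is now pad-aware: the padded batch is split by `attention_mask` first, and each micro-batch is packed afterwards. For intuition, a toy illustration of what packing a right-padded pair of sequences yields (shapes only, not the library implementation):

import torch

input_ids = torch.tensor([[5, 6, 0], [7, 8, 9]])       # lengths 2 and 3, pad id 0
attention_mask = torch.tensor([[1, 1, 0], [1, 1, 1]])

lens = attention_mask.sum(1)                            # tensor([2, 3])
packed = input_ids[attention_mask.bool()]               # tensor([5, 6, 7, 8, 9])
cu_seqlens = torch.nn.functional.pad(lens.cumsum(0), (1, 0))  # tensor([0, 2, 5])
# HF models consume the packed ids as shape [1, total_seqlen],
# hence the unsqueeze mentioned in the note above.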

View File

@ -90,7 +90,7 @@ class HFEngine(TrainEngine):
)
torch.cuda.set_device("cuda:0")
dtype = torch.bfloat16 if self.engine_config.bf16 else torch.float16
dtype = getattr(torch, self.config.dtype)
self.model_config = AutoConfig.from_pretrained(
pretrained_model_name_or_path=self.engine_config.path,
trust_remote_code=True,

View File

@ -0,0 +1,324 @@
import functools
from typing import Dict, List, Optional
import torch
from tensordict import TensorDict
from arealite.api.cli_args import MicroBatchSpec, PPOActorConfig
from arealite.api.engine_api import TrainEngine
from arealite.engine.fsdp_engine import FSDPEngine
from arealite.utils.data import pack_tensor_dict, split_padded_tensor_dict_into_mb_list
from arealite.utils.functional import (
calc_entropy,
gather_logprobs,
masked_normalization,
ppo_actor_loss_fn,
)
from realhf.base import stats_tracker
def grpo_loss_fn(
logits: torch.Tensor,
input_data: Dict,
temperature: float,
eps_clip: float,
c_clip: float | None,
behav_imp_weight_cap: float | None,
):
"""Loss function for actor step, all inputs should be splitted into
pipeline micro batches, returns loss and logging stats."""
input_ids = input_data["input_ids"]
cu_seqlens = input_data["cu_seqlens"]
old_logp = input_data["logprobs"]
advantages = input_data["advantages"]
loss_mask = input_data["loss_mask"]
prox_logp = input_data.get("prox_logp", None)
logits = logits.float()
logits /= temperature
logprobs = gather_logprobs(logits, torch.roll(input_ids, shifts=-1))
loss, stat = ppo_actor_loss_fn(
logprobs=logprobs,
old_logprobs=old_logp,
advantages=advantages,
eps_clip=eps_clip,
loss_mask=loss_mask,
c_clip=c_clip,
proximal_logprobs=prox_logp,
behav_imp_weight_cap=behav_imp_weight_cap,
)
entropy = calc_entropy(logits=logits, cu_seqlens=cu_seqlens)
# Log training statistics
stats_tracker.denominator(
n_tokens=torch.ones(logits.shape[0], dtype=torch.bool, device=logits.device),
n_valid_tokens=loss_mask.bool(),
clipped_tokens=stat["clip_mask"],
dual_clipped_tokens=stat["dual_clip_mask"],
)
stats_tracker.stat(
importance_weight=stat["importance_weight"],
approx_kl=stat["approx_kl"],
new_logp=logprobs.detach(),
old_logp=old_logp,
entropy=entropy.float(),
actor_loss=stat["loss"],
clip_ratio=stat["clip_mask"].float(),
dual_clip_ratio=stat["dual_clip_mask"].float(),
denominator="n_valid_tokens",
)
if "behave_imp_weight" in stat:
stats_tracker.denominator(unclipped_behave_tokens=stat["behave_mask"])
stats_tracker.stat(
behave_imp_weight=stat["behave_imp_weight"],
behave_approx_kl=stat["behave_approx_kl"],
denominator="unclipped_behave_tokens",
)
vocab_min_logits = logits.detach().min(-1).values.float()
vocab_max_logits = logits.detach().max(-1).values.float()
stats_tracker.stat(
vocab_min_logits=vocab_min_logits,
vocab_max_logits=vocab_max_logits,
denominator="n_tokens",
)
clip_mask = stat["clip_mask"]
clipped_new_logp = torch.where(clip_mask, logprobs.detach(), 0.0)
clipped_old_logp = torch.where(clip_mask, old_logp, 0.0)
stats_tracker.stat(
clipped_new_logp=clipped_new_logp,
clipped_old_logp=clipped_old_logp,
denominator="clipped_tokens",
)
return loss
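
For reference, the objective that `ppo_actor_loss_fn` clips is the standard PPO surrogate with an optional dual clip; with importance ratio $r_t = \exp(\log\pi_\theta(a_t) - \log\pi_{\text{old}}(a_t))$ and advantage $\hat{A}_t$ (a sketch of the usual formulation; the exact treatment of `prox_logp` and `behav_imp_weight_cap` lives inside `ppo_actor_loss_fn`):

$$\mathcal{L}_t = -\min\big(r_t \hat{A}_t,\ \operatorname{clip}(r_t,\, 1-\varepsilon,\, 1+\varepsilon)\,\hat{A}_t\big), \qquad \mathcal{L}_t \leftarrow \min\big(\mathcal{L}_t,\ -c\,\hat{A}_t\big) \text{ when } c \text{ is set and } \hat{A}_t < 0,$$

averaged over the tokens selected by `loss_mask`.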
class PPOActor:
def __init__(self, config: PPOActorConfig, engine: TrainEngine):
self.config = config
self.engine = engine
self.reward_bias = config.reward_bias
self.reward_scaling = config.reward_scaling
self.reward_clip = config.reward_clip
self.group_reward_norm = config.group_reward_norm
self.group_adv_norm = config.group_adv_norm
self.group_size = config.group_size
self.kl_ctl = config.kl_ctl
self.adv_norm = config.adv_norm
self.discount = config.discount
self.gae_lambda = config.gae_lambda
self.mask_no_eos_with_zero = config.mask_no_eos_with_zero
self.temperature = config.temperature
@torch.no_grad()
def compute_logp(
self,
data: TensorDict,
temperature: Optional[float] = None,
) -> torch.Tensor | None:
def calc_logprobs(logits, input_data):
logits = logits.float()
labels = torch.roll(input_data["input_ids"], shifts=-1, dims=-1)
if temperature is not None:
logits /= temperature
logprobs = gather_logprobs(logits, labels)
return logprobs
return self.engine.forward(
input_=data,
post_hook=calc_logprobs,
aggregate_fn=lambda xs: torch.cat(xs, dim=-1),
)
def compute_advantages(self, data: TensorDict) -> None:
bs = data["input_ids"].shape[0]
max_seqlen = data["input_ids"].shape[1]
batch_indices = torch.arange(
bs, device=data["input_ids"].device, dtype=torch.long
)
# Compute rewards using the reward function in synchronous RLVR pipeline.
reward_score = data["rewards"]
reward_score = (reward_score + self.reward_bias) * self.reward_scaling
reward_score = torch.clip(
reward_score, max=self.reward_clip, min=-self.reward_clip
)
if self.group_reward_norm:
for i in range(bs // self.group_size):
s = slice(i * self.group_size, (i + 1) * self.group_size)
r = reward_score[s]
reward_score[s] = (r - r.mean()) / (r.std() + 1e-9)
loss_mask = data["prompt_mask"].logical_not().float()
loss_mask = torch.roll(loss_mask, shifts=-1, dims=-1)
# Apply the mask to log probabilities.
old_logp = data["logprobs"]
ref_logp = data["ref_logp"]
ref_logp *= loss_mask
old_logp *= loss_mask
# Compute KL-regularized rewards.
attn_mask = data["attention_mask"]
seqlens = attn_mask.sum(-1).long()
seq_no_eos_mask = seqlens == attn_mask.shape[1]
rewards = -self.kl_ctl * (old_logp - ref_logp)
kl_rewards = rewards.clone()
# The KL reward at the token after EOS is zero.
rewards[batch_indices, seqlens - 1] = 0
indices = torch.clip(seqlens - 2, min=0)
if self.mask_no_eos_with_zero:
rewards[batch_indices, indices] += torch.where(
seq_no_eos_mask, 0, reward_score
)
else:
rewards[batch_indices, indices] += reward_score
# Compute GAE.
if "values" not in data:
values = torch.zeros_like(rewards)
else:
values = data["values"]
advantages_reversed = []
lastgaelam = 0
for t in reversed(range(max_seqlen - 1)):
nextvalues = values[:, t + 1]
if t == max_seqlen - 2:
nextvalues *= seq_no_eos_mask
delta = rewards[:, t] + self.discount * nextvalues - values[:, t]
lastgaelam = delta + self.discount * self.gae_lambda * lastgaelam
advantages_reversed.append(lastgaelam)
advantages_reversed.append(
torch.zeros(bs, dtype=torch.float32, device=values.device)
)
advantages = torch.stack(advantages_reversed[::-1], dim=1)
# Optionally perform advantage normalization.
if self.adv_norm:
if self.group_adv_norm:
adv_list = []
for i in range(0, bs, self.group_size):
s = slice(i, i + self.group_size)
adv = advantages[s]
m = loss_mask[s]
adv_list.append(masked_normalization(adv, m, all_reduce=False))
advantages = torch.cat(adv_list, 0)
else:
advantages = masked_normalization(advantages, loss_mask)
# Store data in the dict.
data["advantages"] = advantages
data["kl_rewards"] = kl_rewards
data["tot_rewards"] = rewards
data["loss_mask"] = loss_mask
def ppo_update(self, data: TensorDict) -> List[Dict[str, float]]:
attn_mask = data["attention_mask"]
loss_mask = data["loss_mask"]
reward_score = data["rewards"]
seqlens = attn_mask.sum(-1)
all_stats = []
########## Logging code starts ##########
result_denominators = {
"correct_n_seqs": (reward_score > 0).bool(),
"incorrect_n_seqs": (reward_score <= 0).bool(),
}
global_denominators = dict(
n_seqs=torch.ones_like(reward_score, dtype=torch.bool),
n_tokens=torch.ones_like(loss_mask, dtype=torch.bool),
n_valid_tokens=loss_mask.bool(),
**result_denominators,
)
stats_tracker.denominator(**global_denominators)
stats_tracker.stat(
correct_seq_len=seqlens.float(), denominator="correct_n_seqs"
)
stats_tracker.stat(
incorrect_seq_len=seqlens.float(), denominator="incorrect_n_seqs"
)
stats = dict(
advantages=data["advantages"],
kl_rewards=data["kl_rewards"],
final_reward=data["tot_rewards"],
)
stats_tracker.stat(**stats, denominator="n_valid_tokens")
prompt_lens = data["prompt_mask"].sum(-1)
seq_stats = dict(
no_eos_ratios=(seqlens == attn_mask.shape[-1]).float(),
task_reward=reward_score.float(),
prompt_len=prompt_lens.float(),
seq_len=seqlens.float(),
)
stats_tracker.stat(**seq_stats, denominator="n_seqs")
scalars = dict(
mask_no_eos_with_zero=self.config.mask_no_eos_with_zero,
eps_clip=self.config.eps_clip,
use_prox_logp="prox_logp" in data,
)
if self.config.c_clip is not None:
scalars["c_clip"] = self.config.c_clip
scalars["use_dual_clip"] = 1
else:
scalars["use_dual_clip"] = 0
if self.config.behav_imp_weight_cap is not None:
scalars["behav_imp_weight_cap"] = self.config.behav_imp_weight_cap
stats_tracker.scalar(**scalars)
global_stats = stats_tracker.export()
for k in global_denominators:
keys = list(global_stats.keys())
for k2 in keys:
if k2.endswith(k):
global_stats.pop(k2)
########## Logging code ends ##########
mb_inputs = split_padded_tensor_dict_into_mb_list(
data,
mb_spec=MicroBatchSpec(n_mbs=self.config.ppo_n_minibatches),
)
for mb in mb_inputs.mbs:
train_stat = self.engine.train_batch(
mb,
loss_fn=functools.partial(
grpo_loss_fn,
temperature=self.temperature,
eps_clip=self.config.eps_clip,
c_clip=self.config.c_clip,
behav_imp_weight_cap=self.config.behav_imp_weight_cap,
),
loss_weight_fn=lambda x: x["loss_mask"].count_nonzero(),
)
stats_tracker.scalar(**train_stat)
all_stats.append(stats_tracker.export())
all_stats[0].update(global_stats)
return all_stats
class FSDPPPOActor(FSDPEngine):
def __init__(self, config: PPOActorConfig):
super().__init__(config)
self.actor = PPOActor(config, self)
@torch.no_grad()
def compute_logp(self, *args, **kwargs) -> torch.Tensor | None:
return self.actor.compute_logp(*args, **kwargs)
@torch.no_grad()
def compute_advantages(self, *args, **kwargs) -> None:
self.actor.compute_advantages(*args, **kwargs)
def ppo_update(self, *args, **kwargs) -> List[Dict[str, float]]:
return self.actor.ppo_update(*args, **kwargs)

View File

@ -1,4 +1,5 @@
import asyncio
import os
import threading
import time
import traceback
@ -7,17 +8,20 @@ from queue import Empty, Full, Queue
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
import aiohttp
import requests
import torch.distributed as dist
from tensordict import TensorDict
from arealite.api.cli_args import InferenceEngineConfig
from arealite.api.engine_api import InferenceEngine
from arealite.api.io_struct import (
FinetuneSpec,
LLMRequest,
LLMResponse,
RolloutStat,
WeightUpdateMeta,
)
from arealite.utils.padding import concat_padded_tensors
from realhf.base import logging, name_resolve, names, pkg_version
if TYPE_CHECKING:
@ -46,22 +50,48 @@ class RemoteSGLangEngine(InferenceEngine):
# Maintain the addresses for the recent 128 requests
self.rid_queue = []
self.addresses = config.server_addrs
addrs = os.getenv("AREAL_LLM_SERVER_ADDRS")
if not addrs:
raise RuntimeError("No SGLang servers configured in AREAL_LLM_SERVER_ADDRS.")
self.addresses = addrs.split(",")
for addr in self.addresses:
self._wait_for_server(addr)
self.server_idx = 0
qsize = config.queue_size or config.max_concurrent_rollouts * 10
qsize = config.queue_size or config.max_concurrent_rollouts * 16
self.input_queue = Queue(maxsize=qsize)
self.output_queue = Queue(maxsize=qsize)
self.result_cache = []
self.exiting = threading.Event()
self.paused = threading.Event()
self.lock = threading.Lock()
self.rollout_stat = RolloutStat()
self._version = 0
def initialize(self, addr: str | None, ft_spec: Optional[Dict[str, Any]] = None):
def _wait_for_server(self, address):
base_url = f"http://{address}"
tik = time.time()
while time.time() - tik < self.config.setup_timeout:
if self.check_health(base_url):
return
time.sleep(1)
raise RuntimeError("server launch failed")
def check_health(self, base_url):
# Check server endpoint
try:
response = requests.get(
f"{base_url}/metrics",
timeout=30,
)
return response.status_code == 200
except requests.exceptions.RequestException:
return False
def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None = None):
self.rollout_thread = threading.Thread(target=self._rollout_thread)
self.rollout_thread.start()
@ -119,13 +149,17 @@ class RemoteSGLangEngine(InferenceEngine):
ofp = self.config.max_head_offpolicyness
with self.lock:
sample_cnt = self.rollout_stat.accepted + self.rollout_stat.running
expected_version = sample_cnt // self.config.consumer_batch_size
consumer_bs = self.config.consumer_batch_size
if dist.is_initialized():
consumer_bs //= dist.get_world_size()
expected_version = sample_cnt // consumer_bs
not_staled = expected_version <= ofp + version
can_rollout &= not_staled
if not not_staled:
cannot_rollout_reason.append(
f"Staled: expected version ({expected_version}) = "
f"global sample cnt ({sample_cnt}) // batch size ({self.config.consumer_batch_size}), "
f"global sample cnt ({sample_cnt}) // batch size ({consumer_bs}), "
f"current latest version {version}, "
f"offpolicyness {self.config.max_head_offpolicyness}."
)
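Concretely: with a global `consumer_batch_size` of 128 on 4 trainer ranks, each engine uses a per-rank batch of 128 // 4 = 32. If 96 samples are accepted or running, `expected_version = 96 // 32 = 3`, and new rollouts are admitted only while 3 <= `version` + `max_head_offpolicyness` (a worked example of the check above, not additional logic).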
@ -137,7 +171,7 @@ class RemoteSGLangEngine(InferenceEngine):
)
# Create new rollout task
if can_rollout and data is not None:
if can_rollout and data is not None and not self.paused.is_set():
task = asyncio.create_task(
workflow.arun_episode(self, data), name=str(rid)
)
@ -191,6 +225,8 @@ class RemoteSGLangEngine(InferenceEngine):
f"running: {self.rollout_stat.running}, "
f"accepted: {self.rollout_stat.accepted}."
)
except Exception:
traceback.print_exc()
finally:
# Cancel remaining tasks
for task in rollout_tasks.values():
@ -236,8 +272,7 @@ class RemoteSGLangEngine(InferenceEngine):
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(
total=timeout,
sock_connect=30,
sock_read=timeout,
sock_connect=timeout,
)
) as session:
if method.upper() == "GET":
@ -252,7 +287,7 @@ class RemoteSGLangEngine(InferenceEngine):
raise ValueError(f"Unsupported HTTP method: {method}")
response.raise_for_status()
return response
return await response.json()
except (
aiohttp.ClientError,
@ -324,7 +359,7 @@ class RemoteSGLangEngine(InferenceEngine):
and len(accumulated_output_tokens) < gconfig.max_new_tokens
):
# loop until the generation is complete
response = await self.arequest_with_retry(
result = await self.arequest_with_retry(
endpoint="/generate",
payload=payload,
method="POST",
@ -332,7 +367,6 @@ class RemoteSGLangEngine(InferenceEngine):
timeout=self.config.request_timeout,
target_addr=server_addr,
)
result = await response.json()
# Parse response
completions += result["text"]
@ -400,7 +434,7 @@ class RemoteSGLangEngine(InferenceEngine):
raise NotImplementedError(f"Unsupported weight update type: {meta.type}")
async def aupdate_weights_from_disk(self, addr, path: str):
response = await self.arequest_with_retry(
res = await self.arequest_with_retry(
endpoint="/update_weights_from_disk",
payload=dict(model_path=str(path), allow_interrupt=True),
method="POST",
@ -408,7 +442,6 @@ class RemoteSGLangEngine(InferenceEngine):
timeout=self.config.request_timeout,
target_addr=addr,
)
res = await response.json()
assert res["success"]
if "num_paused_requests" in res:
logger.info(
@ -422,9 +455,15 @@ class RemoteSGLangEngine(InferenceEngine):
except Full:
raise RuntimeError("Input queue full. Please increase queue_size.")
def wait(self, count: int, timeout: float, should_accept: Callable) -> TensorDict:
def wait(
self,
count: int,
timeout: float | None = None,
should_accept: Callable | None = None,
) -> TensorDict:
tik = time.perf_counter()
accepted = len(self.result_cache)
timeout = timeout or float(7 * 24 * 3600)
while (
accepted < count
and not self.exiting.is_set()
@ -432,7 +471,7 @@ class RemoteSGLangEngine(InferenceEngine):
):
try:
result = self.output_queue.get(timeout=ROLLOUT_POLL_WAIT_TIME)
if should_accept(result):
if should_accept is None or should_accept(result):
self.result_cache.append(result)
accepted += 1
else:
@ -450,7 +489,7 @@ class RemoteSGLangEngine(InferenceEngine):
self.result_cache[:count],
self.result_cache[count:],
)
return TensorDict.cat(results, dim=0)
return concat_padded_tensors(results)
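
Plain `TensorDict.cat` cannot join batches whose sequence dimensions differ, which is what `concat_padded_tensors` is presumably for. A hedged sketch of the assumed semantics (not the actual `arealite.utils.padding` implementation):

import torch

def concat_padded(tensors, pad_value=0):
    """Right-pad [bs, seqlen] tensors to a common length, then cat on dim 0."""
    max_len = max(t.shape[1] for t in tensors)
    padded = [
        torch.nn.functional.pad(t, (0, max_len - t.shape[1]), value=pad_value)
        for t in tensors
    ]
    return torch.cat(padded, dim=0)

a, b = torch.ones(2, 3, dtype=torch.long), torch.ones(1, 5, dtype=torch.long)
print(concat_padded([a, b]).shape)  # torch.Size([3, 5])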
def rollout(
self, data: List[Dict[str, Any]], workflow: "RolloutWorkflow"
@ -458,8 +497,10 @@ class RemoteSGLangEngine(InferenceEngine):
"""Submit a batch of requests to the inference engine and wait for the results."""
for item in data:
self.submit(item, workflow)
return self.wait(
count=len(data),
timeout=self.config.request_timeout,
should_accept=lambda x: True,
)
return self.wait(count=len(data))
def pause(self):
self.paused.set()
def resume(self):
self.paused.clear()

arealite/launcher/local.py Normal file
View File

@ -0,0 +1,310 @@
import argparse
import getpass
import os
import re
import signal as signal_module
import subprocess
import sys
import time
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union
import psutil
from arealite.api.cli_args import SGLangConfig, parse_cli_args, to_structured_cfg
from arealite.api.io_struct import AllocationMode, AllocationType
from arealite.utils.network import find_free_ports, gethostip
from realhf.base import gpu_utils, logging
from realhf.scheduler.client import (
JobException,
JobInfo,
JobState,
SchedulerClient,
SchedulerError,
)
logger = logging.getLogger("Local Scheduler")
JOB_STATE_TO_PROCESS_STATUS = {
JobState.NOT_FOUND: [],
JobState.PENDING: [psutil.STATUS_PARKED],
JobState.RUNNING: [
psutil.STATUS_RUNNING,
psutil.STATUS_SLEEPING,
psutil.STATUS_DISK_SLEEP,
psutil.STATUS_TRACING_STOP,
psutil.STATUS_WAKING,
psutil.STATUS_WAITING,
psutil.STATUS_LOCKED,
psutil.STATUS_IDLE,
],
JobState.COMPLETED: [
psutil.STATUS_DEAD,
psutil.STATUS_STOPPED,
psutil.STATUS_ZOMBIE,
],
JobState.FAILED: [],
JobState.CANCELLED: [],
}
PROCESS_STATUS_TO_JOB_STATE = {}
for job_state, process_statuses in JOB_STATE_TO_PROCESS_STATUS.items():
for process_status in process_statuses:
PROCESS_STATUS_TO_JOB_STATE[process_status] = job_state
def terminate_process_and_children(pid: int, signal: Optional[Union[str, int]] = None):
if signal is None:
signal = signal_module.SIGKILL
if isinstance(signal, str):
signal = getattr(signal_module, signal)
try:
parent = psutil.Process(pid)
children = parent.children(recursive=True)
for child in children:
terminate_process_and_children(child.pid)
parent.send_signal(signal)
except psutil.NoSuchProcess:
pass
class LocalLauncher:
def __init__(self, experiment_name: str, trial_name: str, fileroot: str):
self.experiment_name = experiment_name
self.trial_name = trial_name
self.fileroot = fileroot
self._jobs: Dict[str, subprocess.Popen] = {}
self._job_counter: Dict[str, int] = defaultdict(int)
self._job_states = {}
self._gpu_counter = 0
self._cuda_devices: List[str] = os.environ.get(
"CUDA_VISIBLE_DEVICES", ",".join(map(str, range(gpu_utils.gpu_count())))
).split(",")
if len(self._cuda_devices) < 1:
raise RuntimeError(
f"Local mode can only run when there is at least one GPU. "
f"CUDA_VISIBLE_DEVICES is currently set to {os.environ['CUDA_VISIBLE_DEVICES']}."
)
@property
def run_name(self):
return f"{self.experiment_name}_{self.trial_name}"
def log_path_of(self, job_name: str) -> str:
log_path = f"{self.fileroot}/logs/{getpass.getuser()}/{self.experiment_name}/{self.trial_name}"
os.makedirs(log_path, exist_ok=True)
return os.path.join(log_path, f"{job_name}.log")
def __del__(self):
self.wait()
def submit_array(
self,
job_name: str,
cmd: str | List[str],
count: int = 1,
gpu: int = 0,
env_vars: Optional[Dict] = None,
):
if env_vars is None:
env_vars = {}
if not isinstance(cmd, list):
cmd = [cmd] * count
offset = self._job_counter[job_name]
for i in range(count):
if gpu > 0:
# Allocate GPUs in a round-robin manner
visible_devices = []
for _ in range(gpu):
available_device_id = self._gpu_counter % len(self._cuda_devices)
self._gpu_counter += 1
visible_devices.append(available_device_id)
env_vars["CUDA_VISIBLE_DEVICES"] = ",".join(
str(self._cuda_devices[j]) for j in visible_devices
)
c = (
" ".join(str(k) + "=" + str(v) for k, v in env_vars.items())
+ " stdbuf -oL "
+ cmd[i]
)
c = f"{c} | tee -a {self.log_path_of(job_name)}"
logger.info("Starting local process with command: %s", c)
process = subprocess.Popen(c, shell=isinstance(c, str))
self._jobs[f"{job_name}/{offset + i}"] = process
self._job_counter[job_name] += 1
def submit(
self,
job_name: str,
cmd: str | List[str],
gpu: int = 0,
env_vars: Optional[Dict] = None,
):
self.submit_array(job_name=job_name, cmd=cmd, gpu=gpu, env_vars=env_vars)
def stop(self, job_name, signal=None):
assert any(k.startswith(job_name) for k in self._jobs)
keys = [k for k, p in self._jobs.items() if k.startswith(job_name)]
procs = [p for k, p in self._jobs.items() if k.startswith(job_name)]
logger.info(
f"Stopping local process with signal {signal if signal else 'SIGKILL'}, "
f"pid: {[p.pid for p in procs]}"
)
for p in procs:
terminate_process_and_children(p.pid, signal=signal)
for p in procs:
p.wait()
for k, p in zip(keys, procs):
self._jobs.pop(k)
del p
def stop_all(self, signal=None):
# signal argument is ignored in local stop_all
for name in self._job_counter:
self.stop(name, signal=signal)
def find(self, job_name):
if job_name in self._jobs:
return JobInfo(name=job_name, state=JobState.RUNNING, host="localhost")
else:
return JobInfo(name=job_name, state=JobState.NOT_FOUND)
def find_all(self, job_name_regex=".*"):
rs = []
for name in self._jobs:
if re.fullmatch(job_name_regex, name):
rs.append(self.find(name))
return rs
def wait(
self,
timeout=None,
check_status: Tuple[JobState, ...] = (
JobState.CANCELLED,
JobState.FAILED,
JobState.NOT_FOUND,
),
remove_status: Tuple[JobState, ...] = (JobState.COMPLETED,),
update=False,
):
deadline = None if timeout is None else time.time() + timeout
logger.info(
"Waiting for %d local running processes, pids: %s",
len(self._jobs),
" ".join(str(job.pid) for job in self._jobs.values()),
)
left = set(self._jobs.keys())
num_jobs_left = len(left)
while len(left) > 0:
to_remove = []
if len(left) < num_jobs_left:
num_jobs_left = len(left)
logger.info(f"Waiting for {num_jobs_left} jobs.")
if deadline is not None and time.time() > deadline:
raise TimeoutError(
f"Timeout waiting for {self.run_name}: {', '.join(sorted(left))}"
)
# update job states
for job_name in list(left):
job = self._jobs[job_name]
pid = job.pid
process = psutil.Process(pid)
self._job_states[job_name] = PROCESS_STATUS_TO_JOB_STATE.get(
process.status(), JobState.NOT_FOUND
)
for job_name in list(left):
state = self._job_states[job_name]
if state in check_status:
raise JobException(
run_name=self.run_name,
worker_type=job_name.split("/")[0],
host="local",
reason=state,
)
if state in remove_status:
logger.info(f"Job {job_name} is {state}.(Removed)")
left.remove(job_name)
to_remove.append(job_name)
if update:
for k in to_remove:
self._jobs.pop(k)
worker_type = k.split("/")[0]
assert worker_type in self._job_counter
self._job_counter[worker_type] -= 1
if self._job_counter[worker_type] <= 0:
self._job_counter.pop(worker_type)
time.sleep(2)
def main_local():
cfg, _ = parse_cli_args(sys.argv[2:])
alloc_mode = AllocationMode.from_str(cfg.allocation_mode)
launcher = LocalLauncher(cfg.experiment_name, cfg.trial_name, cfg.cluster.fileroot)
server_cmd = []
server_addrs = []
if alloc_mode.type_ == AllocationType.DECOUPLED_SGLANG:
base_seed = cfg.sglang.random_seed
cfg.sglang = to_structured_cfg(cfg.sglang, SGLangConfig)
ports = find_free_ports(alloc_mode.gen_dp_size * 2, port_range=(10000, 50000))
host_ip = gethostip()
host = "localhost" if not cfg.sglang.enable_metrics else host_ip
for i in range(alloc_mode.gen_dp_size):
cfg.sglang.random_seed = base_seed + i
cmd = SGLangConfig.build_cmd(
cfg.sglang,
host=host,
tp_size=alloc_mode.gen_tp_size,
base_gpu_id=0,
port=ports[i * 2],
dist_init_addr=f"localhost:{ports[i*2+1]}",
)
server_cmd.append(cmd)
server_addrs.append(f"{host}:{ports[i * 2]}")
else:
raise NotImplementedError()
# Launch inference servers.
launcher.submit_array(
job_name="llm_server",
cmd=server_cmd,
count=alloc_mode.gen_dp_size,
gpu=alloc_mode.gen_pp_size * alloc_mode.gen_tp_size,
)
logger.info(
f"LLM inference server launched at: AREAL_LLM_SERVER_ADDRS={','.join(server_addrs)}"
)
# Launch trainer entrypoint
if not cfg.server_only:
launcher.submit(
job_name="trainer",
cmd=f"torchrun --nnodes 1 --nproc-per-node {alloc_mode.train_world_size} {' '.join(sys.argv[1:])}",
gpu=alloc_mode.train_world_size,
env_vars=dict(AREAL_LLM_SERVER_ADDRS=",".join(server_addrs)),
)
try:
launcher.wait(
check_status=(
JobState.CANCELLED,
JobState.FAILED,
JobState.NOT_FOUND,
JobState.COMPLETED,
),
remove_status=(),
)
except (KeyboardInterrupt, JobException, TimeoutError) as e:
launcher.stop_all("SIGTERM")
raise e
if __name__ == "__main__":
main_local()
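
The `allocation_mode` string consumed here (e.g. `sglang.d4p1t1+d4p1t1` in the example config later in this commit) is parsed by `AllocationMode.from_str`. Judging from how `gen_dp_size`, `gen_pp_size`, and `gen_tp_size` are used above, a plausible reading is `d<dp>p<pp>t<tp>` for the generation side before the `+` and the trainer side after it. A hedged regex sketch of that parse, not the actual `arealite.api.io_struct` code:

import re

def parse_parallelism(spec: str) -> dict:
    """Parse 'd4p1t1' into {'dp': 4, 'pp': 1, 'tp': 1} (assumed format)."""
    m = re.fullmatch(r"d(\d+)p(\d+)t(\d+)", spec)
    if m is None:
        raise ValueError(f"Bad parallelism spec: {spec!r}")
    dp, pp, tp = map(int, m.groups())
    return {"dp": dp, "pp": pp, "tp": tp}

gen, train = "sglang.d4p1t1+d4p1t1".removeprefix("sglang.").split("+")
print(parse_parallelism(gen), parse_parallelism(train))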

View File

@ -15,62 +15,41 @@ from arealite.api.cli_args import (
SGLangConfig,
)
from arealite.api.io_struct import LLMRequest, LLMResponse, WeightUpdateMeta
from arealite.utils import network
from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import network
EXPR_NAME = "test_sglang_engine"
TRIAL_NAME = "trial_0"
MODEL_PATH = "/storage/testing/models/Qwen__Qwen3-1.7B/"
if not os.path.exists(MODEL_PATH):
MODEL_PATH = "Qwen/Qwen2-0.5B"
PORT = 13887
DIST_PORT = 15887
PORT, DIST_PORT = network.find_free_ports(2)
HOST = network.gethostip()
def check_server_health(base_url):
# Check server endpoint
try:
response = requests.get(
f"{base_url}/metrics",
timeout=30,
)
return response.status_code == 200
except requests.exceptions.RequestException:
return False
@pytest.fixture(scope="module")
def sglang_server():
from realhf.base import seeding
seeding.set_random_seed(1, EXPR_NAME)
cmd = SGLangConfig.build_cmd(
sglang_config=SGLangConfig(mem_fraction_static=0.3),
model_path=MODEL_PATH,
sglang_config=SGLangConfig(
skip_tokenizer_init=False, model_path=MODEL_PATH, mem_fraction_static=0.3
),
host=HOST,
port=PORT,
tp_size=1,
base_gpu_id=0,
dist_init_addr=f"{HOST}:{DIST_PORT}",
served_model_name=MODEL_PATH,
skip_tokenizer_init=False,
)
# Launch process
full_command = f"{cmd} --port {PORT}"
full_command = full_command.replace("\\\n", " ").replace("\\", " ")
cmd = cmd.replace("\\\n", " ").replace("\\", " ")
process = subprocess.Popen(
full_command.split(),
cmd.split(),
text=True,
stdout=sys.stdout,
stderr=sys.stdout,
)
base_url = f"http://{HOST}:{PORT}"
tik = time.time()
while time.time() - tik < 90:
if check_server_health(base_url):
break
time.sleep(1)
if time.time() - tik > 90:
raise RuntimeError("server launch failed")
yield
process.terminate()
@ -80,7 +59,7 @@ async def test_remote_sglang_generate(sglang_server):
from arealite.engine.sglang_remote import RemoteSGLangEngine
config = InferenceEngineConfig(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME)
config.server_addrs = [f"{HOST}:{PORT}"]
os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}"
engine = RemoteSGLangEngine(config)
req = LLMRequest(
rid=str(uuid.uuid4()),
@ -109,7 +88,7 @@ def test_remote_sglang_rollout(sglang_server, n_samples):
max_concurrent_rollouts=2,
consumer_batch_size=2,
)
config.server_addrs = [f"{HOST}:{PORT}"]
os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}"
engine = RemoteSGLangEngine(config)
engine.initialize(None, None)
@ -147,7 +126,7 @@ def test_remote_sglang_staleness_control(sglang_server, bs, ofp, n_samples):
consumer_batch_size=bs,
max_head_offpolicyness=ofp,
)
config.server_addrs = [f"{HOST}:{PORT}"]
os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}"
engine = RemoteSGLangEngine(config)
engine.initialize(None, None)
@ -220,7 +199,7 @@ def test_disk_update_weights_from_fsdp_engine(tmp_path_factory, sglang_server):
from arealite.engine.sglang_remote import RemoteSGLangEngine
config = InferenceEngineConfig(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME)
config.server_addrs = [f"{HOST}:{PORT}"]
os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}"
inf_engine = RemoteSGLangEngine(config)
# test update weights
path = tmp_path_factory.mktemp("upload_weights_from_disk")

View File

@ -8,7 +8,7 @@ from arealite.utils.data import (
pad_and_stack_tensors_along_first_dim,
pad_sequences_to_tensors,
reorder_list,
split_packed_tensor_dict_into_mb_list,
split_padded_tensor_dict_into_mb_list,
unpack_sequence,
)
@ -45,7 +45,8 @@ def test_micro_batch_split(mock_padded_data, n_mbs, max_tokens_per_mb):
packed_data = pack_tensor_dict(mock_padded_data)
original_lens = packed_data["cu_seqlens"][1:] - packed_data["cu_seqlens"][:-1]
assert torch.allclose(original_lens, mock_padded_data["attention_mask"].sum(1))
split_result = split_packed_tensor_dict_into_mb_list(packed_data, mb_spec)
split_result = split_padded_tensor_dict_into_mb_list(mock_padded_data, mb_spec)
split_result.mbs = [pack_tensor_dict(mb) for mb in split_result.mbs]
reordered_lens = [original_lens[i] for i in split_result.forward_indices]
# assert microbatch split result does not violate requirements

View File

@ -290,13 +290,13 @@ class MicroBatchList:
DEFAULT_MAX_TOKENS_PER_MB = int(1e12)
def split_packed_tensor_dict_into_mb_list(
def split_padded_tensor_dict_into_mb_list(
data: TensorDict, mb_spec: MicroBatchSpec, group: Optional[dist.ProcessGroup] = None
) -> MicroBatchList:
"""Split a packed tensordict into micro-batches based on the cumulative sequence lengths.
"""Split a padded tensordict into micro-batches based on the attention mask.
Args:
data (TensorDict): Dictionary containing packed tensors with "cu_seqlens" key.
data (TensorDict): Dictionary containing padded tensors.
mb_spec (MicroBatchSpec): Specification for micro-batch splitting.
group (Optional[dist.ProcessGroup]): Process group for distributed synchronization.
@ -304,24 +304,21 @@ def split_packed_tensor_dict_into_mb_list(
MicroBatchList: A structure containing the split micro-batches and metadata.
"""
assert (
"cu_seqlens" in data
), "Input data must be packed and contain 'cu_seqlens' key."
"attention_mask" in data
), "Input data must be padded and contain 'attention_mask' key."
if mb_spec.max_tokens_per_mb is None:
mb_spec = MicroBatchSpec.new(
mb_spec, max_tokens_per_mb=DEFAULT_MAX_TOKENS_PER_MB
)
cu_seqlens = data["cu_seqlens"]
bs = cu_seqlens.shape[0] - 1
total_lens = int(cu_seqlens[-1])
input_lens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy()
bs = data["attention_mask"].shape[0]
max_seqlen = data["attention_mask"].shape[1]
input_lens = data["attention_mask"].sum(1).cpu().numpy()
# Check tensor shapes; split only padded tensors with bs * max_seqlen elements.
to_split = {}
not_to_split = {}
for key, value in data.items():
if key == "cu_seqlens" or key == "max_seqlen":
continue
if not torch.is_tensor(value) or value.numel() != total_lens:
if not torch.is_tensor(value) or value.numel() != bs * max_seqlen:
not_to_split[key] = value
else:
to_split[key] = value
@ -331,6 +328,7 @@ def split_packed_tensor_dict_into_mb_list(
splitted_lens = [
[input_lens[i] for i in group_index] for group_index in group_indices
]
group_n_seqs = [len(x) for x in splitted_lens]
group_lens = [sum(x) for x in splitted_lens]
forward_indices = datapack.flat2d(group_indices)
@ -340,12 +338,16 @@ def split_packed_tensor_dict_into_mb_list(
def _split(tensor):
"""Split and pad a tensor based on forward indices and lens."""
# Unpack the sequence
unpacked = unpack_sequence(tensor, cu_seqlens=cu_seqlens)
unpacked = [tensor[i] for i in range(bs)]
# Reorder according to forward indices
reordered = reorder_list(unpacked, forward_indices)
reordered = torch.cat(reordered)
reordered = torch.stack(reordered)
# Regroup according to the number of sequences per micro-batch
splitted = unpack_sequence(reordered, lens=group_lens)
splitted = []
offset = 0
for _n_seqs in group_n_seqs:
splitted.append(reordered[offset : offset + _n_seqs])
offset += _n_seqs
return splitted
to_split = dict_map(to_split, lambda x: _split(x))
@ -355,16 +357,7 @@ def split_packed_tensor_dict_into_mb_list(
# organize splitted micro batches
assert len(mbs) == len(splitted_lens), (len(mbs), len(splitted_lens))
for i, (mb, lens) in enumerate(zip(mbs, splitted_lens)):
max_seqlen = max(lens)
lens = torch.tensor(lens, device="cuda")
batch_cu_seqlens = torch.nn.functional.pad(
lens.cumsum(0, dtype=torch.int), (1, 0)
)
results.append(
TensorDict(
**mb, **not_to_split, max_seqlen=max_seqlen, cu_seqlens=batch_cu_seqlens
)
)
results.append(TensorDict(**mb, **not_to_split))
return MicroBatchList(
data=data,
mbs=results,
@ -433,7 +426,7 @@ def pad_mb_list(
# NOTE: GPU page size is 2MB
# Take hidden size 4096 with bf16 dtype as an example,
# the batch size of a page is 256
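# e.g., a 1000-token micro-batch is padded to (1000 + 255) // 256 * 256 == 1024.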
pad_to_length = (l + 255) // 256 * 256
pad_to_length = (int(l) + 255) // 256 * 256
padded_mb, pad_len = pad_packed_tensor_dict(
mb, pad_to_length, pad_value=pad_value
)

View File

@ -4,10 +4,6 @@ from arealite.api.cli_args import EvaluatorConfig
from arealite.api.io_struct import FinetuneSpec
from realhf.base import timeutil
if TYPE_CHECKING:
from tensordict import TensorDict
from torchdata.stateful_dataloader import StatefulDataLoader
class Evaluator:
@ -22,8 +18,7 @@ class Evaluator:
def evaluate(
self,
valid_dataloader: "StatefulDataLoader",
evaluate_fn: Callable[["TensorDict"], Any],
evaluate_fn: Callable,
epoch: int,
step: int,
global_step: int,
@ -32,5 +27,4 @@ class Evaluator:
epochs=int(step == self.ft_sepc.steps_per_epoch - 1), steps=1
):
return
for data in valid_dataloader:
evaluate_fn(data)
evaluate_fn()

arealite/utils/network.py Normal file
View File

@ -0,0 +1,100 @@
import random
import socket
from typing import List, Set
def gethostname():
return socket.gethostname()
def gethostip():
return socket.gethostbyname(socket.gethostname())
def find_free_ports(
count: int, port_range: tuple = (1024, 65535), exclude_ports: Set[int] | None = None
) -> List[int]:
"""
Find multiple free ports within a specified range.
Args:
count: Number of free ports to find
port_range: Tuple of (min_port, max_port) to search within
exclude_ports: Set of ports to exclude from search
Returns:
List of free port numbers
Raises:
ValueError: If unable to find requested number of free ports
"""
if exclude_ports is None:
exclude_ports = set()
min_port, max_port = port_range
free_ports = []
attempted_ports = set()
# Calculate available port range
available_range = max_port - min_port + 1 - len(exclude_ports)
if count > available_range:
raise ValueError(
f"Cannot find {count} ports in range {port_range}. "
f"Only {available_range} ports available."
)
max_attempts = count * 10 # Reasonable limit to avoid infinite loops
attempts = 0
while len(free_ports) < count and attempts < max_attempts:
# Generate random port within range
port = random.randint(min_port, max_port)
# Skip if port already attempted or excluded
if port in attempted_ports or port in exclude_ports:
attempts += 1
continue
attempted_ports.add(port)
if is_port_free(port):
free_ports.append(port)
attempts += 1
if len(free_ports) < count:
raise ValueError(
f"Could only find {len(free_ports)} free ports "
f"out of {count} requested after {max_attempts} attempts"
)
return sorted(free_ports)
def is_port_free(port: int) -> bool:
"""
Check if a port is free by attempting to bind to it.
Args:
port: Port number to check
Returns:
True if port is free, False otherwise
"""
# Check TCP
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
sock.bind(("", port))
sock.close()
except OSError:
return False
# Check UDP
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
sock.bind(("", port))
sock.close()
return True
except OSError:
return False
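
A small usage sketch of the helpers above, mirroring how the local launcher reserves a server port and a dist-init port per replica; ports are drawn randomly, so the values are illustrative:

server_port, dist_port = find_free_ports(2, port_range=(10000, 50000))
assert is_port_free(server_port) and is_port_free(dist_port)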

View File

@ -0,0 +1,112 @@
experiment_name: gsm8k-grpo
trial_name: trial0
allocation_mode: sglang.d4p1t1+d4p1t1
n_nodes: 1
n_gpus_per_node: 8
cluster:
fileroot: /tmp/arealite/experiments
name_resolve:
type: nfs
nfs_record_root: /tmp/areal/name_resolve
seed: 1
total_train_epochs: 1
tokenizer_path: ${actor.path}
rollout:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
max_concurrent_rollouts: null
queue_size: null
consumer_batch_size: ${train_dataset.batch_size}
max_head_offpolicyness: 0
gconfig:
min_new_tokens: 0
max_new_tokens: 1024
greedy: false
temperature: 1.0
actor:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
path: /storage/openpsi/models/Qwen__Qwen3-1.7B/
init_from_scratch: false
gradient_checkpointing: false
dtype: float16
mb_spec:
max_tokens_per_mb: 10240
optimizer:
type: adam
lr: 2e-5
weight_decay: 0.05
beta1: 0.9
beta2: 0.95
eps: 1e-5
lr_scheduler_type: cosine
gradient_clipping: 1.0
backend: fsdp
ref:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
path: ${actor.path}
init_from_scratch: false
dtype: ${actor.dtype}
mb_spec:
max_tokens_per_mb: 10240
optimizer: null
backend: fsdp
# SGLang
server_only: false
sglang:
model_path: ${actor.path}
random_seed: ${seed}
skip_tokenizer_init: false
dtype: ${actor.dtype}
max_running_requests: null
context_length: 32768
mem_fraction_static: 0.9
# datasets
train_dataset:
batch_size: 128
shuffle: true
pin_memory: true
valid_dataset:
batch_size: 128
shuffle: true
pin_memory: true
# Utilities
saver:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
fileroot: ${cluster.fileroot}
freq_epochs: 1
freq_steps: null
freq_secs: null
checkpointer:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
fileroot: ${cluster.fileroot}
freq_epochs: 1
freq_steps: null
freq_secs: 3600
evaluator:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
fileroot: ${cluster.fileroot}
freq_epochs: null
freq_steps: 1
freq_secs: null
stats_logger:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
fileroot: ${cluster.fileroot}
wandb:
mode: disabled

View File

@ -16,7 +16,7 @@ model:
path: /storage/openpsi/models/Qwen__Qwen3-1.7B/
init_from_scratch: false
gradient_checkpointing: false
bf16: true
dtype: bfloat16
mb_spec:
max_tokens_per_mb: 4096
optimizer:
@ -34,13 +34,11 @@ train_dataset:
batch_size: 128
shuffle: true
pin_memory: true
num_workers: 4
valid_dataset:
batch_size: 128
shuffle: true
pin_memory: true
num_workers: 4
# Utilities
saver:

View File

@ -0,0 +1,201 @@
import os
import re
import sys
import torch
from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node
from torchdata.stateful_dataloader import StatefulDataLoader
from arealite.api.cli_args import GRPOConfig, load_expr_config
from arealite.api.io_struct import FinetuneSpec
from arealite.engine.ppo.actor import FSDPPPOActor
from arealite.engine.sglang_remote import RemoteSGLangEngine
from arealite.utils.evaluator import Evaluator
from arealite.utils.saver import Saver
from arealite.utils.stats_logger import StatsLogger
from arealite.workflow.rlvr import RLVRWorkflow
from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import stats_tracker
def process_gsm8k_rl_dataset(dataset: Dataset):
def process(sample):
messages = [{"role": "user", "content": sample["question"]}]
return {"messages": messages, "method": "strict"}
dataset = dataset.map(process).remove_columns(["question"])
return dataset
def get_gsm8k_dataset(split, rank, world_size):
dataset = load_dataset(path="openai/gsm8k", name="main", split=split)
dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
return process_gsm8k_rl_dataset(dataset)
# Adapted from verl.
def extract_solution(solution_str, method="strict"):
assert method in ["strict", "flexible"]
final_answer = None
if method == "strict":
# this also tests the formatting of the model
solutions = re.findall("#### (\\-?[0-9\\.\\,]+)", solution_str)
if len(solutions) == 0:
final_answer = None
else:
# take the last solution
final_answer = solutions[-1].replace(",", "").replace("$", "")
elif method == "flexible":
answer = re.findall("(\\-?[0-9\\.\\,]+)", solution_str)
final_answer = None
if len(answer) == 0:
# no reward if there is no answer
pass
else:
invalid_str = ["", "."]
# find the last number that is not '.'
for final_answer in reversed(answer):
if final_answer not in invalid_str:
break
return final_answer
def gsm8k_reward_fn(
prompt, completions, prompt_ids, completion_ids, answer, method, **kwargs
):
sol = extract_solution(solution_str=completions, method=method)
if sol is None:
return 0
return int(sol == answer)
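
To make `strict` mode concrete, a worked example (the `#### ` marker is GSM8K's answer format; the text is abridged from the canonical first training sample):

completion = "Natalia sold 48 clips in April and half as many in May, so 48 + 24 = 72. #### 72"
assert extract_solution(completion, method="strict") == "72"
assert gsm8k_reward_fn(None, completion, None, None, answer="72", method="strict") == 1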
def main_grpo():
config, _ = load_expr_config(sys.argv[1:], GRPOConfig)
config: GRPOConfig
rank = int(os.getenv("RANK"))
world_size = int(os.getenv("WORLD_SIZE"))
tokenizer = load_hf_tokenizer(config.tokenizer_path)
# Create dataset and dataloaders
train_dataloader = StatefulDataLoader(
get_gsm8k_dataset("train", rank, world_size),
batch_size=config.train_dataset.batch_size // world_size,
shuffle=config.train_dataset.shuffle,
num_workers=config.train_dataset.num_workers,
collate_fn=lambda x: x,
drop_last=config.train_dataset.drop_last,
)
valid_dataloader = StatefulDataLoader(
get_gsm8k_dataset("test", rank, world_size),
batch_size=config.valid_dataset.batch_size // world_size,
shuffle=config.valid_dataset.shuffle,
num_workers=config.valid_dataset.num_workers,
collate_fn=lambda x: x,
drop_last=config.valid_dataset.drop_last,
)
ft_spec = FinetuneSpec(
total_train_epochs=config.total_train_epochs,
dataset_size=len(train_dataloader) * config.train_dataset.batch_size,
train_batch_size=config.train_dataset.batch_size,
)
# Initialize inference engine
rollout = RemoteSGLangEngine(config.rollout)
rollout.initialize(None, ft_spec)
eval_rollout = RemoteSGLangEngine(config.rollout)
eval_rollout.initialize(None, ft_spec)
# Initialize train engine
actor = FSDPPPOActor(config=config.actor)
actor.initialize(None, ft_spec)
ref = None
if config.actor.kl_ctl > 0 and config.ref is not None:
ref = FSDPPPOActor(config=config.ref)
ref.initialize(None, ft_spec)
# Create rollout workflow
if tokenizer.pad_token_id not in config.gconfig.stop_token_ids:
config.gconfig.stop_token_ids.append(tokenizer.pad_token_id)
if tokenizer.eos_token_id not in config.gconfig.stop_token_ids:
config.gconfig.stop_token_ids.append(tokenizer.eos_token_id)
workflow = RLVRWorkflow(
reward_fn=gsm8k_reward_fn, gconfig=config.gconfig, tokenizer=tokenizer
)
# Run training.
saver = Saver(config.saver, ft_spec, for_recover=False)
logger = StatsLogger(config.stats_logger, ft_spec)
evaluator = Evaluator(config.evaluator, ft_spec)
total_epochs = config.total_train_epochs
steps_per_epoch = len(train_dataloader)
logger.info(f"total_epochs={total_epochs} step_per_epoch={steps_per_epoch}")
global_step = 0
for epoch in range(total_epochs):
for step, data in enumerate(train_dataloader):
with stats_tracker.record_timing("rollout"):
batch = rollout.rollout(data, workflow=workflow)
batch = batch.to(actor.device)
if ref is not None:
with stats_tracker.record_timing("ref_logp"):
batch["ref_logp"] = ref.compute_logp(batch)
with stats_tracker.record_timing("compute_advantage"):
actor.compute_advantages(batch)
with (
stats_tracker.record_timing("train_step"),
stats_tracker.scope("grpo_actor"),
):
stats = actor.ppo_update(batch)
actor.step_lr_scheduler()
with stats_tracker.record_timing("save"):
saver.save(actor, epoch, step, global_step)
with stats_tracker.record_timing("eval"):
def evaluate_fn():
rollout.pause()
cnt = 0
for data in valid_dataloader:
for item in data:
eval_rollout.submit(item, workflow)
cnt += 1
batch = eval_rollout.wait(cnt, timeout=None)
rewards = batch["rewards"]
with stats_tracker.scope("grpo-eval"):
stats_tracker.denominator(
n_seqs=torch.ones(
rewards.shape[0],
device=rewards.device,
dtype=rewards.dtype,
)
)
stats_tracker.stat(task_reward=rewards, denominator="n_seqs")
rollout.resume()
evaluator.evaluate(
evaluate_fn,
epoch,
step,
global_step,
)
logger.commit(epoch, step, global_step, stats)
global_step += 1
actor.destroy()
if ref is not None:
ref.destroy()
logger.close()
if __name__ == "__main__":
main_grpo()

View File

@ -95,12 +95,16 @@ def main_sft():
with stats_tracker.record_timing("save"):
saver.save(engine, epoch, step, global_step)
with stats_tracker.record_timing("eval"), stats_tracker.scope("sft-eval"):
with stats_tracker.record_timing("eval"):
# No need to log anything. Logging will be handled outside
# via stats_tracker.export().
def evaluate_fn():
with stats_tracker.scope("sft-eval"):
for data in valid_dataloader:
engine.evaluate_lm(data)
evaluator.evaluate(
valid_dataloader,
engine.evaluate_lm,
evaluate_fn,
epoch,
step,
global_step,