Qwen2 GRPO works

bowei.fw 2025-07-12 09:44:50 +08:00
parent 0cbddb8aba
commit 888751da38
15 changed files with 336 additions and 235 deletions

View File

@ -9,7 +9,6 @@ from hydra import initialize as hydra_init
from omegaconf import MISSING, OmegaConf
from arealite.utils.fs import get_user_tmp
from realhf.api.cli_args import OptimizerConfig
@dataclass
@ -81,6 +80,64 @@ class GenerationHyperparameters:
# Train Engine Configs
@dataclass
class OptimizerConfig:
"""Configuration for model optimization during training.
Note:
Set type to "empty" for models that won't be trained.
"""
type: str = field(
default="adam",
metadata={"help": "Optimizer type", "choices": ["adam", "empty"]},
)
lr: float = field(default=2e-5, metadata={"help": "Learning rate"})
weight_decay: float = field(default=0.05, metadata={"help": "Weight decay"})
beta1: float = field(default=0.9, metadata={"help": "Adam beta1 parameter"})
beta2: float = field(default=0.95, metadata={"help": "Adam beta2 parameter"})
eps: float = field(default=1e-5, metadata={"help": "Adam epsilon parameter"})
min_lr_ratio: float = field(
default=0.0,
metadata={
"help": "Minimum learning rate ratio after annealing",
},
)
lr_scheduler_type: str = field(
default="constant",
metadata={
"help": "Learning rate scheduler type",
"choices": ["linear", "cosine", "constant"],
},
)
warmup_steps_proportion: float = field(
default=0.001,
metadata={
"help": "Proportion of training steps for warmup",
},
)
offload: bool = field(
default=False, metadata={"help": "Enable optimizer state offloading"}
)
initial_loss_scale: float = field(
default=2**32, metadata={"help": "Initial loss scaling factor"}
)
min_loss_scale: float = field(
default=1.0, metadata={"help": "Minimum loss scaling factor"}
)
loss_scale_window: float = field(
default=5, metadata={"help": "Window size for loss scaling adjustment"}
)
hysteresis: int = field(
default=2, metadata={"help": "Hysteresis (number of overflow steps tolerated before lowering the loss scale)"}
)
gradient_clipping: float = field(
default=1.0, metadata={"help": "Gradient clipping threshold"}
)
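For orientation, a minimal sketch of how these fields would map onto a torch optimizer, assuming the "adam" type corresponds to AdamW as asserted in the FSDP engine below; the wrapper function itself is illustrative and not part of this commit:

import torch

def build_optimizer(model: torch.nn.Module, cfg: OptimizerConfig) -> torch.optim.Optimizer:
    # Sketch only: "adam" is mapped to AdamW with the hyperparameters above.
    assert cfg.type == "adam", "set type='empty' for models that are not trained"
    return torch.optim.AdamW(
        model.parameters(),
        lr=cfg.lr,
        weight_decay=cfg.weight_decay,
        betas=(cfg.beta1, cfg.beta2),
        eps=cfg.eps,
    )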
@dataclass
class FSDPWrapPolicy:
transformer_layer_cls_to_wrap: Optional[List[str]] = field(
@ -284,7 +341,7 @@ class SGLangConfig:
port,
dist_init_addr: Optional[str] = None,
):
from realhf.base import network, pkg_version, seeding
from realhf.base import pkg_version
from realhf.experiments.common.utils import asdict as conf_as_dict
args: Dict = conf_as_dict(sglang_config)
@ -361,6 +418,7 @@ class InferenceEngineConfig:
"the request will not be accepted.",
},
)
enable_rollout_tracing: bool = field(default=False)
schedule_policy: str = field(
default="round_robin",
metadata={"help": "Request scheduling policy", "choices": ["round_robin"]},
@ -663,8 +721,11 @@ def load_expr_config(argv: List[str], config_cls) -> Tuple[BaseExperimentConfig,
assert isinstance(cfg, BaseExperimentConfig)
# Setup environment
from realhf.base import constants, name_resolve
from realhf.base import constants, name_resolve, names
constants.set_experiment_trial_names(cfg.experiment_name, cfg.trial_name)
name_resolve.reconfigure(cfg.cluster.name_resolve)
name_resolve.clear_subtree(
names.trial_root(experiment_name=cfg.experiment_name, trial_name=cfg.trial_name)
)
return cfg, str(config_file)

View File

@ -16,7 +16,6 @@ from arealite.api.cli_args import GenerationHyperparameters
@dataclass
class LLMRequest:
rid: str = field(default_factory=lambda: str(uuid.uuid4()))
text: Optional[str] = None
input_ids: List[int] = field(default_factory=list)
gconfig: GenerationHyperparameters = field(
default_factory=GenerationHyperparameters
@ -28,7 +27,6 @@ class LLMRequest:
@dataclass
class LLMResponse:
# outputs
completions: str
input_tokens: List[int] = field(default_factory=list)
output_tokens: List[int] = field(default_factory=list)
output_logprobs: List[float] = field(default_factory=list)

View File

@ -1,6 +1,7 @@
import gc
import os
import time
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional
import torch
@ -101,13 +102,23 @@ class FSDPEngine(TrainEngine):
trust_remote_code=True,
)
self.tokenizer = load_hf_tokenizer(self.config.path)
tik = time.perf_counter()
with torch.device("cuda"):
# initialize scratch model from config
model = AutoModelForCausalLM.from_config(
self.model_config,
torch_dtype=dtype,
attn_implementation=self.config.attn_impl,
)
if self.config.init_from_scratch:
# initialize scratch model from config
model = AutoModelForCausalLM.from_config(
self.model_config,
torch_dtype=dtype,
attn_implementation=self.config.attn_impl,
)
else:
model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=self.config.path,
trust_remote_code=True,
torch_dtype=dtype,
attn_implementation=self.config.attn_impl,
)
logger.info(f"Model creation and loading time: {time.perf_counter() - tik}")
# Simple auto wrap policy
self.mixed_precision_policy = MixedPrecisionPolicy(
@ -129,23 +140,14 @@ class FSDPEngine(TrainEngine):
}
# Wrap with FSDP2
tik = time.perf_counter()
apply_fsdp2(model, fsdp_kwargs, self.config.fsdp.wrap_policy)
logger.info(f"Applying FSDP2 time: {time.perf_counter() - tik}")
self.model = model
if not self.config.init_from_scratch:
# Load model from an initial checkpoint path,
# which should only be a HuggingFace checkpoint.
load_meta = SaveLoadMeta(
path=self.config.path,
weight_format="hf",
with_optim=False,
tokenizer=None,
base_model_path=self.config.path,
)
self.load(load_meta)
# Set up optimizer
if self.optimizer_config is not None:
tik = time.perf_counter()
assert (
self.optimizer_config.type == "adam"
), "Only AdamW optimizer is supported in this engine."
@ -189,6 +191,7 @@ class FSDPEngine(TrainEngine):
raise ValueError(
f"Unknown lr scheduler type {self.optimizer_config.lr_scheduler_type}"
)
logger.info(f"Create optimizer time: {time.perf_counter() - tik}")
self.initialized = True
@ -300,7 +303,9 @@ class FSDPEngine(TrainEngine):
self.config.trial_name,
meta.model_version,
)
name_resolve.add(update_name, str(time.time_ns()), keepalive_ttl=120)
name_resolve.add(
update_name, str(datetime.now().timestamp()), keepalive_ttl=120
)
else:
raise ValueError(f"Unknown weight update type {meta.type}")
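The disk-based update is coordinated through name_resolve: the trainer publishes a wall-clock timestamp under a per-version key, and the inference engine later blocks on that key to measure how long the handshake took. A minimal sketch of the two sides, with experiment and trial names as placeholders:

from datetime import datetime
from realhf.base import name_resolve, names

update_name = names.update_weights_from_disk("my-exp", "my-trial", 1)

# Trainer side, after the checkpoint has been written to disk:
name_resolve.add(update_name, str(datetime.now().timestamp()), keepalive_ttl=120)

# Inference side, before asking the servers to reload:
save_timestamp = float(name_resolve.wait(update_name, timeout=120))
print(f"trainer published weights {datetime.now().timestamp() - save_timestamp:.2f}s ago")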

View File

@ -7,94 +7,16 @@ from tensordict import TensorDict
from arealite.api.cli_args import MicroBatchSpec, PPOActorConfig
from arealite.api.engine_api import TrainEngine
from arealite.engine.fsdp_engine import FSDPEngine
from arealite.utils.data import pack_tensor_dict, split_padded_tensor_dict_into_mb_list
from arealite.utils.data import split_padded_tensor_dict_into_mb_list
from arealite.utils.functional import (
calc_entropy,
gather_logprobs,
gather_logprobs_entropy,
masked_normalization,
ppo_actor_loss_fn,
)
from realhf.base import stats_tracker
def grpo_loss_fn(
logits: torch.Tensor,
input_data: Dict,
temperature: float,
eps_clip: float,
c_clip: float | None,
behav_imp_weight_cap: float | None,
):
"""Loss function for actor step, all inputs should be splitted into
pipeline micro batches, returns loss and logging stats."""
input_ids = input_data["input_ids"]
cu_seqlens = input_data["cu_seqlens"]
old_logp = input_data["logprobs"]
advantages = input_data["advantages"]
loss_mask = input_data["loss_mask"]
prox_logp = input_data.get("prox_logp", None)
logits = logits.float()
logits /= temperature
logprobs = gather_logprobs(logits, torch.roll(input_ids, shifts=-1))
loss, stat = ppo_actor_loss_fn(
logprobs=logprobs,
old_logprobs=old_logp,
advantages=advantages,
eps_clip=eps_clip,
loss_mask=loss_mask,
c_clip=c_clip,
proximal_logprobs=prox_logp,
behav_imp_weight_cap=behav_imp_weight_cap,
)
entropy = calc_entropy(logits=logits, cu_seqlens=cu_seqlens)
# Log training statistics
stats_tracker.denominator(
n_tokens=torch.ones(logits.shape[0], dtype=torch.bool, device=logits.device),
n_valid_tokens=loss_mask.bool(),
clipped_tokens=stat["clip_mask"],
dual_clipped_tokens=stat["dual_clip_mask"],
)
stats_tracker.stat(
importance_weight=stat["importance_weight"],
approx_kl=stat["approx_kl"],
new_logp=logprobs.detach(),
old_logp=old_logp,
entropy=entropy.float(),
actor_loss=stat["loss"],
clip_ratio=stat["clip_mask"].float(),
dual_clip_ratio=stat["dual_clip_mask"].float(),
denominator="n_valid_tokens",
)
if "behave_imp_weight" in stat:
stats_tracker.denominator(unclipped_behave_tokens=stat["behave_mask"])
stats_tracker.stat(
behave_imp_weight=stat["behave_imp_weight"],
behave_approx_kl=stat["behave_approx_kl"],
denominator="unclipped_behave_tokens",
)
vocab_min_logits = logits.detach().min(-1).values.float()
vocab_max_logits = logits.detach().max(-1).values.float()
stats_tracker.stat(
vocab_min_logits=vocab_min_logits,
vocab_max_logits=vocab_max_logits,
denominator="n_tokens",
)
clip_mask = stat["clip_mask"]
clipped_new_logp = torch.where(clip_mask, logprobs.detach(), 0.0)
clipped_old_logp = torch.where(clip_mask, old_logp, 0.0)
stats_tracker.stat(
clipped_new_logp=clipped_new_logp,
clipped_old_logp=clipped_old_logp,
denominator="clipped_tokens",
)
return loss
class PPOActor:
def __init__(self, config: PPOActorConfig, engine: TrainEngine):
@ -126,11 +48,8 @@ class PPOActor:
) -> torch.Tensor | None:
def calc_logprobs(logits, input_data):
logits = logits.float()
labels = torch.roll(input_data["input_ids"], shifts=-1, dims=-1)
if temperature is not None:
logits /= temperature
logprobs = gather_logprobs(logits, labels)
logprobs = gather_logprobs(logits, labels, temperature or 1.0)
return logprobs
return self.engine.forward(
@ -159,10 +78,13 @@ class PPOActor:
reward_score[s] = (r - r.mean()) / (r.std() + 1e-9)
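# GRPO-style group normalization: each prompt's group of n_samples rewards is
# standardized in place (zero mean, unit std within the group), which is what
# turns raw rewards into advantages without a learned critic.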
loss_mask = data["prompt_mask"].logical_not().float()
loss_mask = torch.roll(loss_mask, shifts=-1)
loss_mask = torch.roll(loss_mask, shifts=-1, dims=-1)
# Apply the mask to log probabilities.
old_logp = data["logprobs"]
ref_logp = data["ref_logp"]
if not self.config.recompute_logprob:
old_logp = torch.roll(data["logprobs"], shifts=-1, dims=-1)
else:
old_logp = data["logprobs"]
ref_logp = data.get("ref_logp", torch.zeros_like(old_logp))
ref_logp *= loss_mask
old_logp *= loss_mask
@ -219,6 +141,8 @@ class PPOActor:
data["kl_rewards"] = kl_rewards
data["tot_rewards"] = rewards
data["loss_mask"] = loss_mask
# Because old_logp has already been rolled by -1 above, write it back for the loss.
data["logprobs"] = old_logp
def ppo_update(self, data: TensorDict) -> List[Dict[str, float]]:
attn_mask = data["attention_mask"]
@ -284,6 +208,8 @@ class PPOActor:
global_stats.pop(k2)
########## Logging code ends ##########
for key in ["rewards", "tot_rewards", "kl_rewards", "versions"]:
data.pop(key, None)
mb_inputs = split_padded_tensor_dict_into_mb_list(
data,
mb_spec=MicroBatchSpec(n_mbs=self.config.ppo_n_minibatches),
@ -322,3 +248,79 @@ class FSDPPPOActor(FSDPEngine):
def ppo_update(self, *args, **kwargs) -> List[Dict[str, float]]:
return self.actor.ppo_update(*args, **kwargs)
def grpo_loss_fn(
logits: torch.Tensor,
input_data: Dict,
temperature: float,
eps_clip: float,
c_clip: float | None,
behav_imp_weight_cap: float | None,
):
"""Loss function for actor step, all inputs should be splitted into
pipeline micro batches, returns loss and logging stats."""
input_ids = input_data["input_ids"]
old_logp = input_data["logprobs"]
advantages = input_data["advantages"]
loss_mask = input_data["loss_mask"].bool()
prox_logp = input_data.get("prox_logp", None)
logprobs, entropy = gather_logprobs_entropy(
logits, torch.roll(input_ids, shifts=-1, dims=-1), temperature
)
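# Labels are input_ids rolled left by one position (next-token prediction), so
# logprobs[t] scores token t+1; loss_mask and old_logp were rolled the same way
# in PPOActor before being passed in.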
entropy = entropy.detach()
loss, stat = ppo_actor_loss_fn(
logprobs=logprobs,
old_logprobs=old_logp,
advantages=advantages,
eps_clip=eps_clip,
loss_mask=loss_mask,
c_clip=c_clip,
proximal_logprobs=prox_logp,
behav_imp_weight_cap=behav_imp_weight_cap,
)
# Log training statistics
stats_tracker.denominator(
n_tokens=torch.ones(logits.shape[0], dtype=torch.bool, device=logits.device),
n_valid_tokens=loss_mask.bool(),
clipped_tokens=stat["clip_mask"],
dual_clipped_tokens=stat["dual_clip_mask"],
)
stats_tracker.stat(
importance_weight=stat["importance_weight"],
approx_kl=stat["approx_kl"],
new_logp=logprobs.detach(),
old_logp=old_logp,
entropy=entropy.float(),
actor_loss=stat["loss"],
clip_ratio=stat["clip_mask"].float(),
dual_clip_ratio=stat["dual_clip_mask"].float(),
denominator="n_valid_tokens",
)
if "behave_imp_weight" in stat:
stats_tracker.denominator(unclipped_behave_tokens=stat["behave_mask"])
stats_tracker.stat(
behave_imp_weight=stat["behave_imp_weight"],
behave_approx_kl=stat["behave_approx_kl"],
denominator="unclipped_behave_tokens",
)
vocab_min_logits = logits.detach().min(-1).values.float()
vocab_max_logits = logits.detach().max(-1).values.float()
stats_tracker.stat(
vocab_min_logits=vocab_min_logits,
vocab_max_logits=vocab_max_logits,
denominator="n_tokens",
)
clip_mask = stat["clip_mask"]
clipped_new_logp = torch.where(clip_mask, logprobs.detach(), 0.0)
clipped_old_logp = torch.where(clip_mask, old_logp, 0.0)
stats_tracker.stat(
clipped_new_logp=clipped_new_logp,
clipped_old_logp=clipped_old_logp,
denominator="clipped_tokens",
)
return loss

View File

@ -4,6 +4,7 @@ import threading
import time
import traceback
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from queue import Empty, Full, Queue
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
@ -34,7 +35,7 @@ if pkg_version.is_available("sglang"):
else:
SGLANG_TOKEN_OUTPUT_IDENTIFIER = "token_ids"
ROLLOUT_POLL_WAIT_TIME = 0.4
ROLLOUT_POLL_WAIT_TIME = 0.1
RID_CACHE_SIZE = 128
@ -53,8 +54,10 @@ class RemoteSGLangEngine(InferenceEngine):
self.addresses = os.getenv("AREAL_LLM_SERVER_ADDRS").split(",")
if not self.addresses:
raise RuntimeError("No configured SGLang servers.")
logger.info("Waiting for server ready...")
for addr in self.addresses:
self._wait_for_server(addr)
logger.info("Servers are all ready!")
self.server_idx = 0
@ -115,7 +118,7 @@ class RemoteSGLangEngine(InferenceEngine):
traceback.print_exc()
async def _rollout_thread_async(self):
data = None
pending_data = []
rollout_tasks: Dict[str, asyncio.Task] = {}
rid = 0
@ -123,12 +126,14 @@ class RemoteSGLangEngine(InferenceEngine):
try:
while not self.exiting.is_set():
# Load next data from controller
if data is None:
while True:
try:
data, workflow = self.input_queue.get_nowait()
logger.info(f"Get data from puller: {data}")
logger.debug(f"Get data from puller: {data}")
pending_data.append(data)
except Empty:
logger.debug(f"No data from puller stream.")
break
# Check capacity
if dist.is_initialized():
@ -136,59 +141,36 @@ class RemoteSGLangEngine(InferenceEngine):
else:
world_size = 1
cannot_rollout_reason = []
capacity = max(1, self.config.max_concurrent_rollouts // world_size)
can_rollout = len(rollout_tasks) < capacity
if not can_rollout:
cannot_rollout_reason.append(
f"Exceeding capacity: # running tasks {len(rollout_tasks)} >= capacity {capacity}"
)
max_concurrent_rollouts = max(
1, self.config.max_concurrent_rollouts // world_size
)
capacity = max_concurrent_rollouts - len(rollout_tasks)
# Staleness control
version = self.get_version()
ofp = self.config.max_head_offpolicyness
with self.lock:
sample_cnt = self.rollout_stat.accepted + self.rollout_stat.running
consumer_bs = self.config.consumer_batch_size
if dist.is_initialized():
consumer_bs //= dist.get_world_size()
expected_version = sample_cnt // consumer_bs
not_staled = expected_version <= ofp + version
can_rollout &= not_staled
if not not_staled:
cannot_rollout_reason.append(
f"Staled: expected version ({expected_version}) = "
f"global sample cnt ({sample_cnt}) // batch size ({consumer_bs}), "
f"current latest version {version}, "
f"offpolicyness {self.config.max_head_offpolicyness}."
)
if not can_rollout:
logger.debug(
f"Cannot submit new rollouts. "
+ "\n".join(cannot_rollout_reason)
)
consumer_bs = max(1, self.config.consumer_batch_size // world_size)
capacity = min(capacity, (ofp + version + 1) * consumer_bs - sample_cnt)
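# Worked example with illustrative numbers: ofp = 1, current version = 2,
# per-rank consumer_bs = 4, sample_cnt = 10 already submitted -> the staleness
# cap is (1 + 2 + 1) * 4 - 10 = 6 additional rollouts before the next update.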
# Create new rollout task
if can_rollout and data is not None and not self.paused.is_set():
while capacity > 0 and pending_data and not self.paused.is_set():
task = asyncio.create_task(
workflow.arun_episode(self, data), name=str(rid)
workflow.arun_episode(self, pending_data.pop(0)), name=str(rid)
)
rollout_tasks[str(rid)] = task
with self.lock:
self.rollout_stat.submitted += 1
self.rollout_stat.running += 1
logger.info(
f"Submit rollout rid {rid}. "
f"Submit: {self.rollout_stat.submitted}, "
f"running: {self.rollout_stat.running}, "
f"accepted: {self.rollout_stat.accepted}."
)
if self.config.enable_rollout_tracing:
logger.info(
f"Submit rollout rid {rid}. "
f"Submit: {self.rollout_stat.submitted}, "
f"running: {self.rollout_stat.running}, "
f"accepted: {self.rollout_stat.accepted}."
)
capacity -= 1
rid += 1
data = None
# Wait for rollout completion
tasks = list(rollout_tasks.values())
@ -199,8 +181,10 @@ class RemoteSGLangEngine(InferenceEngine):
timeout=ROLLOUT_POLL_WAIT_TIME,
return_when=asyncio.FIRST_COMPLETED,
)
if not done:
await asyncio.sleep(1)
else:
await asyncio.sleep(ROLLOUT_POLL_WAIT_TIME)
await asyncio.sleep(1)
# Collect done results
for task in done:
@ -219,12 +203,13 @@ class RemoteSGLangEngine(InferenceEngine):
with self.lock:
self.rollout_stat.running -= 1
logger.info(
f"Finish rollout {task_rid}. "
f"Submit: {self.rollout_stat.submitted}, "
f"running: {self.rollout_stat.running}, "
f"accepted: {self.rollout_stat.accepted}."
)
if self.config.enable_rollout_tracing:
logger.info(
f"Finish rollout {task_rid}. "
f"Submit: {self.rollout_stat.submitted}, "
f"running: {self.rollout_stat.running}, "
f"accepted: {self.rollout_stat.accepted}."
)
except Exception:
traceback.print_exc()
finally:
@ -323,15 +308,11 @@ class RemoteSGLangEngine(InferenceEngine):
# NOTE: rid should NOT be passed in payload
payload = {
"text": req.text,
"input_ids": req.input_ids.copy(),
"sampling_params": sample_params,
"return_logprob": True,
"stream": False,
}
if req.text:
payload["text"] = req.text
else:
payload["input_ids"] = req.input_ids
# Make request
start_time = time.perf_counter()
@ -369,7 +350,6 @@ class RemoteSGLangEngine(InferenceEngine):
)
# Parse response
completions += result["text"]
meta_info = result["meta_info"]
output_tokens = [x[1] for x in meta_info["output_token_logprobs"]]
output_logprobs = [x[0] for x in meta_info["output_token_logprobs"]]
@ -384,12 +364,11 @@ class RemoteSGLangEngine(InferenceEngine):
finish_reason = meta_info["finish_reason"]
stop_reason = finish_reason["type"]
payload["text"] += result["text"]
payload["input_ids"] += result[SGLANG_TOKEN_OUTPUT_IDENTIFIER]
latency = time.perf_counter() - start_time
return LLMResponse(
completions=completions,
input_tokens=req.input_ids,
output_tokens=accumulated_output_tokens,
output_logprobs=accumulated_output_logprobs,
@ -410,10 +389,10 @@ class RemoteSGLangEngine(InferenceEngine):
update_name = names.update_weights_from_disk(
self.config.experiment_name, self.config.trial_name, meta.model_version
)
save_timestamp = int(name_resolve.wait(update_name, timeout=120))
load_timestamp = time.time_ns()
save_timestamp = float(name_resolve.wait(update_name, timeout=120))
load_timestamp = datetime.now().timestamp()
logger.info(
f"Begin update weights from {meta.path}, responded in {(load_timestamp - save_timestamp)/1e6:.2f} ms"
f"Begin update weights from {meta.path}, responded in {(load_timestamp - save_timestamp):.2f}s"
)
try:
jobs = [
@ -427,7 +406,7 @@ class RemoteSGLangEngine(InferenceEngine):
finally:
loop.close()
logger.info(
f"Loading weights done in {(time.time_ns() - load_timestamp)/1e6:.2f} ms"
f"Loading weights done in {(datetime.now().timestamp() - load_timestamp):.2f}s"
)
self.set_version(meta.model_version)
else:
@ -478,7 +457,7 @@ class RemoteSGLangEngine(InferenceEngine):
with self.lock:
self.rollout_stat.accepted -= 1
except Empty:
time.sleep(ROLLOUT_POLL_WAIT_TIME)
pass
if self.exiting.is_set():
raise RuntimeError("Rollout engine is exiting, cannot wait for results.")
if accepted < count:

View File

@ -1,4 +1,3 @@
import argparse
import getpass
import os
import re
@ -15,13 +14,7 @@ from arealite.api.cli_args import SGLangConfig, parse_cli_args, to_structured_cf
from arealite.api.io_struct import AllocationMode, AllocationType
from arealite.utils.network import find_free_ports, gethostip
from realhf.base import gpu_utils, logging
from realhf.scheduler.client import (
JobException,
JobInfo,
JobState,
SchedulerClient,
SchedulerError,
)
from realhf.scheduler.client import JobException, JobInfo, JobState
logger = logging.getLogger("Local Scheduler")
@ -286,7 +279,7 @@ def main_local():
if not cfg.server_only:
launcher.submit(
job_name="trainer",
cmd=f"torchrun --nnodes 1 --nproc-per-node {alloc_mode.train_world_size} {' '.join(sys.argv[1:])}",
cmd=f"torchrun --nnodes 1 --nproc-per-node {alloc_mode.train_world_size} --standalone {' '.join(sys.argv[1:])}",
gpu=alloc_mode.train_world_size,
env_vars=dict(AREAL_LLM_SERVER_ADDRS=",".join(server_addrs)),
)

View File

@ -5,7 +5,6 @@ import time
import uuid
import pytest
import requests
import torch
from tensordict import TensorDict
@ -34,7 +33,9 @@ def sglang_server():
seeding.set_random_seed(1, EXPR_NAME)
cmd = SGLangConfig.build_cmd(
sglang_config=SGLangConfig(
skip_tokenizer_init=False, model_path=MODEL_PATH, mem_fraction_static=0.3
skip_tokenizer_init=True,
model_path=MODEL_PATH,
mem_fraction_static=0.3,
),
host=HOST,
port=PORT,
@ -59,11 +60,12 @@ async def test_remote_sglang_generate(sglang_server):
from arealite.engine.sglang_remote import RemoteSGLangEngine
config = InferenceEngineConfig(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME)
tokenizer = load_hf_tokenizer(MODEL_PATH)
os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}"
engine = RemoteSGLangEngine(config)
req = LLMRequest(
rid=str(uuid.uuid4()),
text="hello! how are you today",
input_ids=tokenizer.encode("hello! how are you today"),
gconfig=GenerationHyperparameters(max_new_tokens=16),
)
resp = await engine.agenerate(req)
@ -74,7 +76,6 @@ async def test_remote_sglang_generate(sglang_server):
== len(resp.output_tokens)
== len(resp.output_versions)
)
assert isinstance(resp.completions, str)
@pytest.mark.parametrize("n_samples", [1, 2, 4])
@ -101,6 +102,7 @@ def test_remote_sglang_rollout(sglang_server, n_samples):
reward_fn=lambda **kwargs: 1.0, # Dummy reward function
gconfig=gconfig,
tokenizer=tokenizer,
enable_thinking=False,
)
data = {
@ -139,6 +141,7 @@ def test_remote_sglang_staleness_control(sglang_server, bs, ofp, n_samples):
reward_fn=lambda **kwargs: 1.0, # Dummy reward function
gconfig=gconfig,
tokenizer=tokenizer,
enable_thinking=False,
)
data = {
"messages": [{"role": "user", "content": "Hello, how are you?"}],

View File

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Callable
from typing import Callable
from arealite.api.cli_args import EvaluatorConfig
from arealite.api.io_struct import FinetuneSpec

View File

@ -2,12 +2,24 @@ import torch
@torch.compile
def gather_logprobs(logits: torch.Tensor, labels: torch.Tensor):
log_probs = torch.nn.functional.log_softmax(logits.float(), dim=-1)
def gather_logprobs(
logits: torch.Tensor, labels: torch.Tensor, temperature: float = 1.0
):
log_probs = torch.nn.functional.log_softmax(logits.float() / temperature, dim=-1)
log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
return log_probs_labels
@torch.compile
def gather_logprobs_entropy(
logits: torch.Tensor, labels: torch.Tensor, temperature: float = 1.0
):
log_probs = torch.nn.functional.log_softmax(logits.float() / temperature, dim=-1)
entropy = -torch.sum(log_probs.exp() * log_probs, dim=-1)
log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
return log_probs_labels, entropy
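A quick sanity check for the new temperature-aware helpers (shapes and values are arbitrary; both functions are defined just above):

import torch

logits = torch.randn(8, 32)            # [tokens, vocab]
labels = torch.randint(0, 32, (8,))    # next-token ids
logp, entropy = gather_logprobs_entropy(logits, labels, temperature=1.0)
assert logp.shape == entropy.shape == (8,)
# Dividing the logits by the temperature inside the helper is equivalent to
# scaling them at the call site, which is how grpo_loss_fn now uses it.
assert torch.allclose(gather_logprobs(logits, labels, 0.7),
                      gather_logprobs(logits / 0.7, labels))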
from typing import Dict, Optional, Tuple
import numpy as np
@ -15,14 +27,6 @@ import torch
import torch.distributed as dist
@torch.compile
@torch.no_grad()
def calc_entropy(logits, cu_seqlens):
probs = torch.nn.functional.softmax(logits.detach().float(), dim=-1)
entropy = -torch.sum(probs * torch.log(probs + 1e-7), dim=-1)
return entropy
@torch.no_grad()
def masked_normalization(
x: torch.Tensor,

View File

@ -51,7 +51,9 @@ class Saver:
epochs=int(step == self.ft_sepc.steps_per_epoch - 1), steps=1
):
return
path = self.get_save_checkpoint_path(epoch, step, global_step, name)
path = Saver.get_save_checkpoint_path(
self.config, epoch, step, global_step, name
)
weight_format = "hf"
with_optim = False
if self.for_recover:

View File

@ -17,19 +17,24 @@ class RLVRWorkflow(RolloutWorkflow):
reward_fn,
gconfig: GenerationHyperparameters,
tokenizer: PreTrainedTokenizerFast,
enable_thinking: bool,
):
self.reward_fn = reward_fn
self.gconfig = gconfig
self.tokenizer = tokenizer
self.enable_thinking = enable_thinking
async def arun_episode(self, engine, data):
text = self.tokenizer.apply_chat_template(
data["messages"], tokenize=False, add_generation_prompt=True
input_ids = self.tokenizer.apply_chat_template(
data["messages"],
tokenize=True,
add_generation_prompt=True,
enable_thinking=self.enable_thinking,
)
n_samples = self.gconfig.n_samples
req = LLMRequest(
rid=uuid.uuid4().hex,
text=text,
input_ids=input_ids,
gconfig=self.gconfig.new(n_samples=1),
)
resps = await asyncio.gather(*[engine.agenerate(req) for _ in range(n_samples)])
@ -42,8 +47,8 @@ class RLVRWorkflow(RolloutWorkflow):
versions = [-1] * resp.input_len + resp.output_versions
reward = self.reward_fn(
prompt=req.text,
completions=resp.completions,
prompt=self.tokenizer.decode(input_ids),
completions=self.tokenizer.decode(resp.output_tokens),
prompt_ids=resp.input_tokens,
completion_ids=resp.output_tokens,
**data,
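A minimal standalone sketch of the new tokenized prompt construction (model name and message are placeholders; enable_thinking is simply forwarded to the chat template, which Qwen3-style templates consume and other templates ignore):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct")  # placeholder model
input_ids = tok.apply_chat_template(
    [{"role": "user", "content": "Hello, how are you?"}],
    tokenize=True,
    add_generation_prompt=True,
    enable_thinking=False,
)
# input_ids is a list of token ids, matching what LLMRequest expects after
# this commit (the text field was removed from the request).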

View File

@ -9,43 +9,55 @@ cluster:
type: nfs
nfs_record_root: /tmp/areal/name_resolve
seed: 1
total_train_epochs: 1
total_train_epochs: 10
tokenizer_path: ${actor.path}
rollout:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
max_concurrent_rollouts: null
max_concurrent_rollouts: 256
queue_size: null
consumer_batch_size: ${train_dataset.batch_size}
max_head_offpolicyness: 0
gconfig:
n_samples: 4
min_new_tokens: 0
max_new_tokens: 1024
max_new_tokens: 512
greedy: false
temperature: 1.0
actor:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
path: /storage/openpsi/models/Qwen__Qwen3-1.7B/
path: /storage/openpsi/models/Qwen__Qwen2-1.5B-Instruct/
init_from_scratch: false
gradient_checkpointing: false
dtype: float16
dtype: bfloat16
mb_spec:
max_tokens_per_mb: 10240
optimizer:
type: adam
lr: 2e-5
weight_decay: 0.05
lr: 1e-6
weight_decay: 0.01
beta1: 0.9
beta2: 0.95
eps: 1e-5
lr_scheduler_type: cosine
beta2: 0.999
eps: 1e-8
lr_scheduler_type: constant
gradient_clipping: 1.0
warmup_steps_proportion: 0.001
backend: fsdp
group_size: ${gconfig.n_samples}
group_adv_norm: false
eps_clip: 0.2
temperature: ${gconfig.temperature}
reward_scaling: 10.0
reward_bias: -0.5
kl_ctl: 0.0
ppo_n_minibatches: 4
recompute_logprob: true
ref:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
@ -62,7 +74,7 @@ server_only: false
sglang:
model_path: ${actor.path}
random_seed: ${seed}
skip_tokenizer_init: false
skip_tokenizer_init: true
dtype: ${actor.dtype}
max_running_requests: null
context_length: 32768
@ -70,12 +82,12 @@ sglang:
# datasets
train_dataset:
batch_size: 128
batch_size: 1024
shuffle: true
pin_memory: true
valid_dataset:
batch_size: 128
batch_size: 1024
shuffle: true
pin_memory: true
@ -100,8 +112,8 @@ evaluator:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
fileroot: ${cluster.fileroot}
freq_epochs: null
freq_steps: 1
freq_epochs: 1
freq_steps: null
freq_secs: null
stats_logger:

View File

@ -3,12 +3,13 @@ import re
import sys
import torch
import torch.distributed as dist
from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node
from torchdata.stateful_dataloader import StatefulDataLoader
from arealite.api.cli_args import GRPOConfig, load_expr_config
from arealite.api.io_struct import FinetuneSpec
from arealite.api.io_struct import FinetuneSpec, WeightUpdateMeta
from arealite.engine.ppo.actor import FSDPPPOActor
from arealite.engine.sglang_remote import RemoteSGLangEngine
from arealite.utils.evaluator import Evaluator
@ -22,7 +23,7 @@ from realhf.base import stats_tracker
def process_gsm8k_rl_dataset(dataset: Dataset):
def process(sample):
messages = [{"role": "user", "content": sample["question"]}]
return {"messages": messages, "method": "strict"}
return {"messages": messages}
dataset = dataset.map(process).remove_columns(["question"])
return dataset
@ -35,7 +36,7 @@ def get_gsm8k_dataset(split, rank, world_size):
# Adapted from verl.
def extract_solution(solution_str, method="strict"):
def extract_solution(solution_str, method="strict") -> str | None:
assert method in ["strict", "flexible"]
final_answer = None
@ -62,13 +63,16 @@ def extract_solution(solution_str, method="strict"):
return final_answer
def gsm8k_reward_fn(
prompt, completions, prompt_ids, completion_ids, answer, method, **kwargs
):
sol = extract_solution(solution_str=completions, method=method)
def gsm8k_reward_fn(prompt, completions, prompt_ids, completion_ids, answer, **kwargs):
from realhf.impl.dataset.math_parser import extract_answer
sol = extract_answer(completions, data_name="math")
ans = extract_solution(solution_str=answer, method="strict")
if sol is None:
return 0
return int(sol == answer)
if ans is None:
return 0
return int(sol.strip() == ans.strip())
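For reference, RLVRWorkflow calls this reward with decoded text plus the raw token ids; a hedged sketch of that call, where sample and tokenizer are placeholders:

reward = gsm8k_reward_fn(
    prompt=tokenizer.decode(input_ids),
    completions=tokenizer.decode(resp.output_tokens),
    prompt_ids=resp.input_tokens,
    completion_ids=resp.output_tokens,
    answer=sample["answer"],  # GSM8K reference string ending in "#### <number>"
)
# Returns 1 only when the answer parsed from the completion matches the
# "####" reference answer, otherwise 0.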
def main_grpo():
@ -107,6 +111,8 @@ def main_grpo():
rollout.initialize(None, ft_spec)
eval_rollout = RemoteSGLangEngine(config.rollout)
eval_rollout.initialize(None, ft_spec)
# NOTE: set a very large version so that eval rollouts are never throttled by the off-policyness control
eval_rollout.set_version(int(1e12))
# Initialize train engine
actor = FSDPPPOActor(config=config.actor)
@ -122,7 +128,10 @@ def main_grpo():
if tokenizer.eos_token_id not in config.gconfig.stop_token_ids:
config.gconfig.stop_token_ids.append(tokenizer.eos_token_id)
workflow = RLVRWorkflow(
reward_fn=gsm8k_reward_fn, gconfig=config.gconfig, tokenizer=tokenizer
reward_fn=gsm8k_reward_fn,
gconfig=config.gconfig,
tokenizer=tokenizer,
enable_thinking=False,
)
# Run training.
@ -137,11 +146,16 @@ def main_grpo():
global_step = 0
for epoch in range(total_epochs):
for step, data in enumerate(train_dataloader):
with stats_tracker.record_timing("rollout"):
batch = rollout.rollout(data, workflow=workflow)
batch = batch.to(actor.device)
if config.actor.recompute_logprob:
with stats_tracker.record_timing("recompute_logp"):
batch["logprobs"] = actor.compute_logp(batch)
if ref is not None:
with stats_tracker.record_timing("ref_logp"):
batch["ref_logp"] = ref.compute_logp(batch)
@ -156,6 +170,22 @@ def main_grpo():
stats = actor.ppo_update(batch)
actor.step_lr_scheduler()
with stats_tracker.record_timing("update_weights"):
meta = WeightUpdateMeta(
type="disk",
path=os.path.join(config.cluster.fileroot, "update_weights"),
alloc_mode=None,
comm_backend=None,
model_version=global_step + 1,
)
if dist.get_rank() == 0:
future = rollout.update_weights(meta)
actor.upload_weights(meta)
if dist.get_rank() == 0:
future.result()
rollout.set_version(global_step)
dist.barrier()
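# Ordering: rank 0 issues the non-blocking reload request to the rollout servers,
# all ranks call actor.upload_weights(meta) to write the checkpoint, rank 0 then
# waits for the reload future, and every rank bumps the local rollout version
# before the barrier.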
with stats_tracker.record_timing("save"):
saver.save(actor, epoch, step, global_step)
@ -169,13 +199,13 @@ def main_grpo():
eval_rollout.submit(item, workflow)
cnt += 1
batch = eval_rollout.wait(cnt, timeout=None)
rewards = batch["rewards"]
rewards = batch["rewards"].float()
with stats_tracker.scope("grpo-eval"):
stats_tracker.denominator(
n_seqs=torch.ones(
rewards.shape[0],
device=rewards.device,
dtype=rewards.dtype,
dtype=torch.bool,
)
)
stats_tracker.stat(task_reward=rewards, denominator="n_seqs")
@ -190,11 +220,16 @@ def main_grpo():
logger.commit(epoch, step, global_step, stats)
global_step += 1
break
break
actor.destroy()
if ref is not None:
ref.destroy()
rollout.destroy()
eval_rollout.destroy()
logger.close()
dist.destroy_process_group()
if __name__ == "__main__":

View File

@ -1,6 +1,7 @@
import os
import sys
import torch.distributed as dist
from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node
from torchdata.stateful_dataloader import StatefulDataLoader
@ -115,6 +116,7 @@ def main_sft():
engine.destroy()
logger.close()
dist.destroy_process_group()
if __name__ == "__main__":

View File

@ -193,7 +193,7 @@ class DistributedStatsTracker:
values = self.stats[key]
if key not in self.denominators:
x = sum([x.sum() for x in values])
if reduce_group is not None:
if dist.is_initialized():
dist.all_reduce(x, group=reduce_group)
else:
denominator = self.denominators[key]
@ -205,7 +205,7 @@ class DistributedStatsTracker:
for v, d in zip(values, self.stats[denominator]):
xs.append(torch.where(d, v, 0.0).sum())
x = sum(xs)
if reduce_group is not None:
if dist.is_initialized():
dist.all_reduce(x, group=reduce_group)
return float(x)
@ -221,7 +221,7 @@ class DistributedStatsTracker:
ds.append(d.sum())
x = sum(xs)
d = sum(ds)
if reduce_group is not None:
if dist.is_initialized():
dist.all_reduce(x, group=reduce_group)
dist.all_reduce(d, group=reduce_group)
if d == 0:
@ -237,7 +237,7 @@ class DistributedStatsTracker:
for v, d in zip(values, self.stats[denominator]):
xs.append(torch.where(d, v, float("inf")).min())
x = min(xs)
if reduce_group is not None:
if dist.is_initialized():
dist.all_reduce(x, group=reduce_group, op=dist.ReduceOp.MIN)
if torch.isinf(x):
return None
@ -252,7 +252,7 @@ class DistributedStatsTracker:
for v, d in zip(values, self.stats[denominator]):
xs.append(torch.where(d, v, -float("inf")).max())
x = max(xs)
if reduce_group is not None:
if dist.is_initialized():
dist.all_reduce(x, group=reduce_group, op=dist.ReduceOp.MAX)
if torch.isinf(x):
return None