# Running GRPO on GSM8K Dataset
This guide introduces how AReaLite runs the GRPO algorithm on the GSM8K dataset, using the training script `examples/arealite/gsm8k_grpo.py` and the configuration file `examples/arealite/configs/gsm8k_grpo.yaml`.
## How AReaLite Works
The following figure illustrates the launch procedure and one asynchronous training step of the GRPO algorithm on the GSM8K dataset in AReaLite. Compared with the old AReaL implementation, AReaLite runs inference servers plus an SPMD training script instead of a collection of specialized workers. In a training step, AReaLite:

1. Submits prompts from the dataset to `RemoteSGLangEngine`, which runs `RLVRWorkflow` in a streaming manner.
2. Completes each `RLVRWorkflow` by interacting with remote `SGLangServer` instances to generate sequences and computing rewards with the reward function.
3. Once there are enough outputs from `RLVRWorkflow`, aggregates them into a data batch for the algorithm-specific training engine `FSDPPPOActor`.
4. Computes losses and updates weights in `FSDPPPOActor`.
5. Transfers the updated weights to the remote `SGLangServer` instances.
In the following sections, we walk through the code to explain these concepts and show how each step is carried out in detail.
## Launching the Experiment
As shown in the Quickstart Guide, experiments in AReaLite are launched using standalone launchers with the following commands:
```bash
# Local Launcher
python -m arealite.launcher.local <training script> --config <configuration file> <cli args>
# Ray Launcher
python -m arealite.launcher.ray <training script> --config <configuration file> <cli args>
# Slurm Launcher
python -m arealite.launcher.slurm <training script> --config <configuration file> <cli args>
```
In AReaLite:

- The training script is an SPMD Python script that serves as the experiment entry point.
- The launcher runs the training script with its distributed backend (`subprocess` for `LocalLauncher`, `ray.remote` for `RayLauncher`, `srun` for `SlurmLauncher`).
- The launcher also manages inference servers (currently only `SGLangServer` is supported). Their number and parallelization strategies (e.g. tensor parallelism) are determined by the `allocation_mode` option.
- For distributed launchers (`RayLauncher` and `SlurmLauncher`), inference servers run inside the wrapper `arealite/launcher/sglang_server.py`, which handles addresses and ports in distributed settings.
- After the `SGLangServer` instances have started, the launcher collects their addresses and ports and sets the `AREAL_LLM_SERVER_ADDRS` environment variable so that training scripts can access these inference servers.
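For illustration only, a training process could read that environment variable with a few lines like the sketch below. The variable name comes from the text above; the comma-separated `host:port` format is an assumption, not a documented contract.

```python
import os

# Hypothetical sketch: parse the inference server list exported by the launcher.
addrs = [a for a in os.environ.get("AREAL_LLM_SERVER_ADDRS", "").split(",") if a]
print(f"Found {len(addrs)} inference servers: {addrs}")
```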
The configuration file is a YAML file that sets the options defined in `arealite/api/cli_args.py`. It can be modified via CLI arguments such as `actor.path=Qwen/Qwen3-1.7B` and `+sglang.attention_backend=triton`. The training script parses the configuration file, together with the CLI arguments, into the config class defined in `arealite/api/cli_args.py`:
```python
config, _ = load_expr_config(args, GRPOConfig)
config: GRPOConfig
```
## Loading and Preprocessing Dataset
We use the `datasets` and `torchdata` packages to load and preprocess the dataset for our dataloader. First, we download `openai/gsm8k` from the Hugging Face Hub, split it by data-parallel rank, and map it into our desired format:
```python
def process_gsm8k_rl_dataset(dataset: Dataset):
    def process(sample):
        messages = [{"role": "user", "content": sample["question"]}]
        return {"messages": messages}

    dataset = dataset.map(process).remove_columns(["question"])
    return dataset


def get_gsm8k_dataset(split, rank, world_size):
    dataset = load_dataset(path="openai/gsm8k", name="main", split=split)
    dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
    return process_gsm8k_rl_dataset(dataset)
```
We then prepare the training and evaluation dataloaders with `torchdata.StatefulDataLoader`:
```python
train_dataloader = torchdata.StatefulDataLoader(
    get_gsm8k_dataset("train", rank, world_size),
    batch_size=config.train_dataset.batch_size // world_size,
    shuffle=config.train_dataset.shuffle,
    num_workers=config.train_dataset.num_workers,
    collate_fn=lambda x: x,
    drop_last=config.train_dataset.drop_last,
)
valid_dataloader = ...
```
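Note that `collate_fn=lambda x: x` keeps each batch as a plain list of per-prompt dicts; tensor batching happens later, after rollout. A quick sanity check (illustrative only) of what a batch looks like:

```python
# Each element is a dict produced by `process_gsm8k_rl_dataset`,
# e.g. {"messages": [{"role": "user", "content": "..."}], "answer": "..."}.
batch = next(iter(train_dataloader))
print(len(batch), batch[0]["messages"][0]["role"])
```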
If you wish to use your own Hugging Face datasets or datasets from local storage, please refer to Customization: Dataset for further details.
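As a starting point, here is a hedged sketch (not from the AReaLite repository) that routes a local JSON Lines file through the same preprocessing, assuming each record has a `question` field. `get_local_rl_dataset` is a hypothetical helper name.

```python
def get_local_rl_dataset(path, split, rank, world_size):
    # Load a local JSONL file with the `datasets` library and reuse the
    # GSM8K-style preprocessing defined above.
    dataset = load_dataset("json", data_files=path, split=split)
    dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
    return process_gsm8k_rl_dataset(dataset)
```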
## Rollout
### Inference Engine: `RemoteSGLangEngine`
In AReaLite, generation tasks are offloaded to remote inference servers, which run on GPUs separate from those used for training. The `RemoteSGLangEngine` acts as a client that interacts with these servers. It runs in an SPMD manner on every training process and does not occupy any GPUs.
`RemoteSGLangEngine` provides two APIs, `agenerate` and `update_weights`. It is worth noting that, in asynchronous RL experiments in AReaLite, an inference-side weight update can happen in the middle of generating a response to a prompt. As a result, a single output sequence may be produced by multiple model versions. Let us look at the code of `agenerate` and `update_weights` for a better understanding.
In `update_weights`, the engine first sends `pause_generation` requests to all inference servers, notifying them that a weight update is about to happen. Upon receiving `pause_generation`, the inference servers immediately stop generating and respond with the tokens generated so far. The engine then sends `update_weights_from_distributed` (for NCCL updates) or `update_weights_from_disk` (for disk updates). After the update has finished, the engine sends `continue_generation` to the inference servers, telling them to resume generation.
```python
class RemoteSGLangEngine:
    ...

    def update_weights(self, meta: WeightUpdateMeta):
        # `update_weights` is completely async.
        # It submits a task to a ProcessPoolExecutor and returns a future.
        for addr in self.addresses:
            res = requests.post(f"http://{addr}/pause_generation")
        if meta.type == "nccl":
            future = self.executor.submit(
                # a function that sends `update_weights_from_distributed` requests
                update_weights_from_distributed,
            )
        elif meta.type == "disk":
            ...

        def callback(future):
            for addr in self.addresses:
                requests.post(f"http://{addr}/continue_generation")

        future.add_done_callback(callback)
        return future
```
`agenerate` takes an `LLMRequest` containing the `input_ids` of a single prompt along with generation hyperparameters, and returns the final generation result as an `LLMResponse` with `output_tokens` and other outputs. Since generation can be interrupted, `agenerate` iteratively prepares the payload, sends requests, and receives responses until the generation finishes.
```python
class RemoteSGLangEngine:
    ...

    async def agenerate(self, req: LLMRequest):
        payload = ...  # prepare payload for request
        # If the request is from the same workflow, choose the old server
        # to allow KV cache reuse. Otherwise choose a server in a round-robin
        # manner.
        server_addr = self.choose_server(req)
        stop_reason = None
        # other outputs are omitted for simplicity
        output_tokens = []
        while stop_reason != "stop" and len(output_tokens) < max_new_tokens:
            # The request was interrupted; wait to avoid contention
            if stop_reason is not None:
                await asyncio.sleep(0.5)
            # send the request to the remote server
            result = await arequest_with_retry(
                addr=server_addr,
                endpoint="/generate",
                payload=payload,
                method="POST",
            )
            output_tokens.extend(result["output_ids"])
            stop_reason = ...  # updated from the server response; details omitted
            # prepare the payload for the next request
            payload["input_ids"] += result["output_ids"]
            payload["sample_params"]["max_new_tokens"] -= len(result["output_ids"])
        return LLMResponse(
            input_tokens=req.input_ids,
            output_tokens=output_tokens,
            ...
        )
```
### `RLVRWorkflow` and `WorkflowExecutor`
The rollout data lifecycle is controlled by an `RLVRWorkflow`, which defines how data progresses from prompts to complete rollout data containing all the fields required for training. Our example uses a single-turn RLVR workflow with a math reward function. The core logic of the workflow is implemented in the async method `arun_episode`, which takes a prompt, generates answers with `RemoteSGLangEngine`, computes rewards, and populates additional fields to produce the finalized training data.
```python
class RLVRWorkflow(RolloutWorkflow):
    def __init__(self, reward_fn, gconfig, tokenizer, ...):
        self.reward_fn = reward_fn
        self.gconfig = gconfig
        self.tokenizer = tokenizer

    async def arun_episode(self, engine, data):
        # roll out data with the inference engine
        input_ids = self.tokenizer.apply_chat_template(data["messages"], ...)
        req = LLMRequest(rid=..., input_ids=input_ids, gconfig=self.gconfig.new(n_samples=1))
        resps = await asyncio.gather(
            *[engine.agenerate(req) for _ in range(self.gconfig.n_samples)]
        )
        # post-process rollout responses
        results = []
        for resp in resps:
            reward = self.reward_fn(...)
            ...  # compute other required fields for training
            res = dict(
                input_ids=...,
                rewards=...,
                # ... other required fields for training
            )
            results.append(res)
        # return `self.gconfig.n_samples` padded samples for prompt `data["messages"]`
        return concat_padded_tensors(results)


def gsm8k_reward_fn(completions, answer):
    ...


tokenizer = load_hf_tokenizer(config.tokenizer_path)
workflow = RLVRWorkflow(
    reward_fn=gsm8k_reward_fn,
    gconfig=config.gconfig,
    tokenizer=tokenizer,
    ...
)
```
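The body of the reward function is omitted above. For intuition only, here is a hedged sketch (not AReaLite's actual implementation, and with a simplified signature) of a verifiable GSM8K reward: extract the number after `####` in the reference answer and compare it against the last number in the completion.

```python
import re


def gsm8k_reward_fn_sketch(completion: str, answer: str) -> float:
    """Illustrative only: 1.0 if the final numbers match, else 0.0."""
    gold = answer.split("####")[-1].strip().replace(",", "")
    nums = re.findall(r"-?\d+(?:\.\d+)?", completion.replace(",", ""))
    return 1.0 if nums and nums[-1] == gold else 0.0
```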
`WorkflowExecutor` is responsible for managing data streaming through the rollout workflows and for collating completed rollout data into batched training samples. Upon initialization, it launches a rollout thread that runs rollout workflows as `asyncio` tasks. The following code shows a simplified version of the rollout thread implementation, which iteratively:

- Checks the available capacity. The capacity limits the number of concurrent rollout workflows to control concurrency and data off-policyness (a hedged sketch of this check follows the code below).
- If there is capacity left and rollout is not paused for a weight update, continuously obtains data from `input_queue` and creates `asyncio` tasks to run the workflows.
- Waits for rollout workflows to finish.
- Gathers data from finished workflows and puts it into `output_queue`.
```python
class WorkflowExecutor:
    ...

    async def _rollout_thread_async(self):
        rid = 0
        rollout_tasks = {}
        try:
            while not self.exiting.is_set():
                # Check capacity
                capacity = self.get_capacity()
                # Create rollout tasks with data obtained from input_queue
                while (
                    capacity > 0
                    and not self.paused.is_set()
                    and self.input_queue.qsize() > 0
                ):
                    data, workflow = self.input_queue.get_nowait()
                    task = asyncio.create_task(
                        workflow.arun_episode(self, data), name=str(rid)
                    )
                    rollout_tasks[str(rid)] = task
                    self.rollout_stat.submitted += 1
                    self.rollout_stat.running += 1
                    capacity -= 1
                    rid += 1
                # Wait for rollout completion
                tasks = list(rollout_tasks.values())
                done = []
                if tasks:
                    done, _ = await asyncio.wait(
                        tasks,
                        timeout=ROLLOUT_POLL_WAIT_TIME,
                        return_when=asyncio.FIRST_COMPLETED,
                    )
                # Collect finished results and put them into the output queue
                for task in done:
                    traj = await task
                    task_rid = task.get_name()
                    rollout_tasks.pop(task_rid)
                    self.rollout_stat.accepted += 1
                    self.output_queue.put_nowait(traj)
                    self.rollout_stat.running -= 1
                await asyncio.sleep(1)
        except Exception:
            ...  # exception handling omitted for brevity
```
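The capacity check mentioned above balances concurrency against staleness. The following is a hedged sketch of the idea; the attribute names (`max_concurrent_rollouts`, `max_head_offpolicyness`, `train_batch_size`, `get_version`) are illustrative and may not match AReaLite's actual configuration fields.

```python
def get_capacity_sketch(self) -> int:
    # Cap the number of in-flight rollout workflows.
    concurrency_budget = self.max_concurrent_rollouts - self.rollout_stat.running
    # Cap how far ahead of the trainer the rollout side may run: allow at most
    # `max_head_offpolicyness` extra model versions' worth of submitted samples.
    staleness_budget = (
        (self.get_version() + self.max_head_offpolicyness + 1) * self.train_batch_size
        - self.rollout_stat.submitted
    )
    return min(concurrency_budget, staleness_budget)
```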
With this rollout thread running, the training script (the main thread) submits prompts into `input_queue` and collates rollout data from `output_queue` into training batches with `prepare_batch` (for asynchronous RL) or `rollout_batch` (for synchronous RL). The following code shows the implementation of `prepare_batch`:
```python
def prepare_batch(
    self,
    dataloader: StatefulDataLoader,
    workflow: "RolloutWorkflow",
):
    if not hasattr(self, "data_generator"):
        self.data_generator = iter(dataloader)
    assert dataloader.batch_size is not None
    while True:
        # Submit at least two batches to allow maximum overlap
        if (
            self.get_capacity() + dataloader.batch_size > 0
            and self.input_queue.qsize() + dataloader.batch_size
            < self.input_queue.maxsize
        ):
            try:
                data = next(self.data_generator)
            except StopIteration:
                self.data_generator = iter(dataloader)
                data = next(self.data_generator)
            for item in data:
                # submit data into input_queue
                self.submit(item, workflow=workflow)
        try:
            # wait for dataloader.batch_size results from output_queue
            return self.wait(dataloader.batch_size, timeout=1)
        except TimeoutError:
            pass
```
The usage of `WorkflowExecutor` in the training script is simple:
```python
inf_engine = RemoteSGLangEngine(config.inf_engine)
rollout = WorkflowExecutor(config.rollout, inf_engine)
rollout.initialize()
eval_rollout = ...

data_generator = iter(train_dataloader)
for global_step in range(max_steps):
    # roll out batched training data for the current step
    if config.async_training:
        batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
    else:
        try:
            data = next(data_generator)
        except StopIteration:
            data_generator = iter(train_dataloader)
            data = next(data_generator)
        batch = rollout.rollout_batch(data, workflow=workflow)
```
If you want to use rollout workflows with custom reward functions or agentic tool calling, see Customization: Rollout Workflows for more details.
## Training
After obtaining the training batch, we use `FSDPPPOActor` to compute losses and update weights. Each train engine corresponds to one model, so we need an additional engine for the reference model. Note that `torch.distributed` process groups are lazily initialized with `init_process_group` when the first train engine is initialized. Initializing a train engine also loads the model weights from the path specified in the configuration.
```python
actor = FSDPPPOActor(config=config.actor)
actor.initialize(None, ft_spec)

ref = None
if config.actor.kl_ctl > 0 and config.ref is not None:
    ref = FSDPPPOActor(config=config.ref)
    ref.initialize(None, ft_spec)
```
`FSDPPPOActor` is a high-level engine with algorithm-specific APIs such as `compute_logp`, `compute_advantages`, and `ppo_update`. It is powered by the lower-level train engine `FSDPEngine`, which uses PyTorch FSDP2 to provide basic model APIs such as `train_batch` and `forward`. The following code shows a GRPO training step:
```python
logp = actor.compute_logp(batch)
batch["prox_logp"] = logp
if ref is not None:
    batch["ref_logp"] = ref.compute_logp(batch)
    log_gpu_stats("ref logp")
actor.compute_advantages(batch)
stats = actor.ppo_update(batch)
actor.step_lr_scheduler()
```
If you want to customize the training algorithm, see Customize algorithms for more details.
## Transferring Weights to Inference Servers
After training, we transfer the updated model weights to the remote inference servers through cooperation between `FSDPPPOActor` and `RemoteSGLangEngine`. Weights can be transferred either via shared storage or via NCCL. In our example training script, we first prepare a `WeightUpdateMeta` for the NCCL backend on all training processes:
```python
# NOTE: The weight update meta only requires the address and a free port of rank 0,
# but `WeightUpdateMeta.from_fsdp_nccl` has to be executed on all ranks
# due to `engine.get_param_specs()`.
# Therefore, we create the weight update meta on all ranks, then broadcast the one from rank 0.
weight_update_meta = [
    WeightUpdateMeta.from_fsdp_nccl(
        AllocationMode.from_str(config.allocation_mode), actor
    )
]
dist.broadcast_object_list(weight_update_meta, src=0)
weight_update_meta = weight_update_meta[0]
```
If you wish to transfer model weights from shared storage, you can use:
```python
weight_update_meta = WeightUpdateMeta.from_disk(config.saver)
```
After a training step is finished, we transfer the new weights from the actor engine to the remote inference servers with the steps shown in the following code:
```python
# 1. Pause rollout on remote inference servers
rollout.pause()
# 2. Send requests to remote servers, telling them to update weights
if dist.get_rank() == 0:
    future = rollout.update_weights(weight_update_meta)
# 3. The actor begins to transfer weights
actor.upload_weights(weight_update_meta)
# 4. Wait for remote servers to return after finishing the update
if dist.get_rank() == 0:
    future.result()
# 5. Synchronize rollout processes for the model version update
dist.barrier(device_ids=[actor.device.index])
torch.cuda.synchronize()
# 6. Resume rollout on remote inference servers
rollout.resume()
# 7. Set the version, ensuring the versions on the actor and the rollout engine are identical
actor.set_version(global_step + 1)
rollout.set_version(global_step + 1)
```
Now a complete GRPO training step in AReaLite is done! The core logic of our example training script can be summarized as:
```python
data_generator = iter(train_dataloader)
for global_step in range(max_steps):
    if config.async_training:
        batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
    else:
        try:
            data = next(data_generator)
        except StopIteration:
            data_generator = iter(train_dataloader)
            data = next(data_generator)
        batch = rollout.rollout_batch(data, workflow=workflow)

    logp = actor.compute_logp(batch)
    batch["prox_logp"] = logp
    if ref is not None:
        batch["ref_logp"] = ref.compute_logp(batch)
        log_gpu_stats("ref logp")
    actor.compute_advantages(batch)
    stats = actor.ppo_update(batch)
    actor.step_lr_scheduler()

    rollout.pause()
    if dist.get_rank() == 0:
        future = rollout.update_weights(weight_update_meta)
    actor.upload_weights(weight_update_meta)
    if dist.get_rank() == 0:
        future.result()
    rollout.resume()
    actor.set_version(global_step + 1)
    rollout.set_version(global_step + 1)
```
## Utilities
AReaLite provides a range of utilities that cover the basic functionality needed to observe and tune your experiments.
### `Saver` and `Evaluator`
`Saver` (`arealite/utils/saver.py`) and `Evaluator` (`arealite/utils/evaluator.py`) manage how frequently the model is saved and evaluated with the train engine.

In our example, we call `saver.save` and `evaluator.evaluate` after every training step. These two methods automatically check whether it is time to save or evaluate the model, according to the experiment configuration.
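For reference, here is a hedged sketch of how these helpers might be wired into the training loop. The constructor arguments, `config.evaluator`, `evaluate_fn`, `epoch`, and `step` are illustrative assumptions; the exact AReaLite API may differ.

```python
# Illustrative only: both calls are expected to no-op unless the configured
# save/eval frequency for the current step has been reached.
saver = Saver(config.saver, ft_spec)
evaluator = Evaluator(config.evaluator, ft_spec)

for global_step in range(max_steps):
    ...  # rollout and training as shown above
    saver.save(actor, epoch, step, global_step)
    evaluator.evaluate(evaluate_fn, epoch, step, global_step)
```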
### `stats_tracker`
`stats_tracker` (`realhf/base/stats_tracker.py`) gathers training statistics across parallel ranks and reduces them.

- Scalar-type statistics are recorded with `stats_tracker.scalar(key=value)` and, when reduced, are averaged over the number of scalars recorded under the same key.
- Tensor-type statistics require a `denominator` and a `reduce_type` to decide how statistics under the same key are reduced. The `denominator` is a boolean tensor that masks out the elements we do not want to record. The `reduce_type` can be average, sum, min, or max; by default, the average, min, and max are all computed.
For example, if we want to record the lengths of sequences with correct and incorrect answers in a training batch:
```python
seqlens = ...  # tensor of shape [#seqs,]
reward_score = ...  # tensor of shape [#seqs,]
result_denominators = {
    "correct_n_seqs": (reward_score > 0).bool(),
    "incorrect_n_seqs": (reward_score <= 0).bool(),
}
# register the denominators
stats_tracker.denominator(**result_denominators)
# record the correct and incorrect sequence lengths
stats_tracker.stat(
    correct_seq_len=seqlens.float(), denominator="correct_n_seqs"
)
stats_tracker.stat(
    incorrect_seq_len=seqlens.float(), denominator="incorrect_n_seqs"
)
```
`stats_tracker` also offers a timer context that records the time cost of a code block as a scalar, and a scope context that manages the keys of statistics:
```python
with stats_tracker.record_timing("train_step"):
    # training step
    ...

with stats_tracker.scope("A"):
    stats_tracker.scalar(c=123)  # key="A/c", value=123
    with stats_tracker.scope("B"):
        stats_tracker.scalar(c=234)  # key="A/B/c", value=234
```
After recording sufficient data, e.g. after a `train_batch` has finished, `stats_tracker.export` is called to aggregate all statistics and dump them into a dictionary:
```python
stats = stats_tracker.export()
```
### `StatsLogger`
`StatsLogger` (`arealite/utils/stats_logger.py`) logs the gathered training data to recorders such as `wandb` and `tensorboard` on rank 0. In our example script, after a training step finishes, `logger.commit(epoch, step, global_step, stats)` is called to record all statistics from `stats_tracker`, printing them as well as logging them to the recorders set in the configuration.
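Putting the logging utilities together, the end of each training step in the example roughly looks like the following (illustrative; `epoch` and `step` follow the example script's bookkeeping):

```python
# Export the reduced statistics from stats_tracker and commit them to the
# configured recorders on rank 0.
stats = stats_tracker.export()
logger.commit(epoch, step, global_step, stats)
```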