Running GRPO on GSM8K Dataset

This guide introduces how AReaL-lite runs the GRPO algorithm on the GSM8K dataset, using the training script examples/arealite/gsm8k_grpo.py and configuration file examples/arealite/configs/gsm8k_grpo.yaml.

How AReaL-lite Works

The following figure illustrates the launch and one asynchronous training step of the GRPO algorithm on the GSM8K dataset in AReaL-lite. Compared with the old AReaL implementation, AReaL-lite runs inference servers and an SPMD training script instead of a collection of different worker types. In a training step, AReaL-lite:

  1. Submits prompts from the dataset to RemoteSGLangEngine, which runs RLVRWorkflow in a streaming manner.
  2. Completes RLVRWorkflow by interacting with remote SGLangServer instances to generate sequences and by computing rewards with the reward function.
  3. Once there are enough outputs from RLVRWorkflow, aggregates them into a data batch for the algorithm-specific training engine FSDPPPOActor.
  4. Computes losses and updates weights in FSDPPPOActor.
  5. Transfers the updated weights to the remote SGLangServer instances.

[Figure: launch and one asynchronous GRPO training step on GSM8K in AReaL-lite (arealite-gsm8k-example)]

In the following sections, we walk through the code to explain these concepts and show how each step is implemented.

Launching the Experiment

As shown in the quickstart guide, experiments in AReaL-lite are launched using standalone launchers with the following commands:

# Local Launcher
python -m arealite.launcher.local <training script> --config <configuration file> <cli args>
# Ray Launcher
python -m arealite.launcher.ray <training script> --config <configuration file> <cli args>
# Slurm Launcher
python -m arealite.launcher.slurm <training script> --config <configuration file> <cli args>

In AReaL-lite:

  • The training script is an SPMD Python script that serves as the experiment entry point.
  • The launcher runs the training script with its distributed backend (subprocess for LocalLauncher, ray.remote for RayLauncher, srun for SlurmLauncher).
  • The launcher also manages inference servers (currently only SGLangServer is supported). The number of servers and their parallelization strategy (e.g., tensor parallelism) are determined by the allocation_mode option.
  • For distributed launchers (RayLauncher and SlurmLauncher), inference servers are started through a wrapper, arealite/launcher/sglang_server.py, that handles addresses and ports in distributed settings.
  • After the SGLangServer instances are started, the launchers collect their addresses and ports and set the AREAL_LLM_SERVER_ADDRS environment variable so that training scripts can access these inference servers.
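As a minimal sketch of the consumer side, a training script could read the server list from AREAL_LLM_SERVER_ADDRS as shown below. The comma-separated `host:port` format and the helper name `parse_server_addrs` are assumptions for illustration only; check the launcher code for the exact format it writes.

```python
import os


def parse_server_addrs(env_var: str = "AREAL_LLM_SERVER_ADDRS") -> list[str]:
    """Return inference server addresses from an environment variable.

    Hypothetical helper: assumes a comma-separated list such as
    "10.0.0.1:30000,10.0.0.2:30000".
    """
    raw = os.environ.get(env_var, "")
    addrs = [a.strip() for a in raw.split(",") if a.strip()]
    if not addrs:
        raise RuntimeError(f"{env_var} is not set or empty")
    for addr in addrs:
        # Each entry must look like "<host>:<numeric port>".
        host, sep, port = addr.rpartition(":")
        if not sep or not host or not port.isdigit():
            raise ValueError(f"malformed address: {addr!r}")
    return addrs


if __name__ == "__main__":
    os.environ["AREAL_LLM_SERVER_ADDRS"] = "10.0.0.1:30000,10.0.0.2:30000"
    print(parse_server_addrs())  # ['10.0.0.1:30000', '10.0.0.2:30000']
```

Failing fast on a missing or malformed variable surfaces launcher misconfiguration before any rollout request is sent.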

The configuration file is a YAML file that sets the options defined in arealite/api/cli_args.py. Individual options can be overridden with CLI arguments such as actor.path=Qwen/Qwen3-1.7B and +sglang.attention_backend=triton. The training script parses the configuration file together with the CLI arguments into the config class defined in arealite/api/cli_args.py:

config, _ = load_expr_config(args, GRPOConfig)
config: GRPOConfig
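To illustrate what a dotted override such as actor.path=Qwen/Qwen3-1.7B does, here is a small sketch that applies overrides to a nested dict. It assumes Hydra-style semantics, where a plain key must already exist and a leading `+` adds a key absent from the config; `apply_overrides` is a hypothetical helper, not the load_expr_config implementation, and it keeps values as strings, whereas a real loader also converts types.

```python
from typing import Any


def apply_overrides(config: dict[str, Any], overrides: list[str]) -> dict[str, Any]:
    """Apply "a.b.c=value" overrides to a nested dict in place (hypothetical)."""
    for item in overrides:
        key, sep, value = item.partition("=")
        if not sep:
            raise ValueError(f"expected key=value, got {item!r}")
        allow_new = key.startswith("+")  # "+" permits adding a missing key
        parts = key.lstrip("+").split(".")
        node = config
        # Walk down to the parent of the leaf key, creating groups only with "+".
        for part in parts[:-1]:
            if part not in node:
                if not allow_new:
                    raise KeyError(f"unknown config group: {part}")
                node[part] = {}
            node = node[part]
        leaf = parts[-1]
        if leaf not in node and not allow_new:
            raise KeyError(f"unknown option {key!r}; prefix with '+' to add it")
        node[leaf] = value
    return config


cfg = {"actor": {"path": "some/base-model"}, "sglang": {}}
apply_overrides(cfg, ["actor.path=Qwen/Qwen3-1.7B", "+sglang.attention_backend=triton"])
print(cfg)
# {'actor': {'path': 'Qwen/Qwen3-1.7B'}, 'sglang': {'attention_backend': 'triton'}}
```

Rejecting unknown keys without `+` catches typos in option names instead of silently adding unused fields.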

Loading and Preprocessing Dataset

We use the datasets and torchdata packages to load and preprocess the dataset into our dataloader. First, we download openai/gsm8k from Hugging Face and split it by data parallel ranks, then map it to our desired format:

def process_gsm8k_rl_dataset(dataset: Dataset):
    def process(sample):
        messages = [{"role": "user", "content": sample["question"]}]
        return {"messages": messages}
    dataset = dataset.map(process).remove_columns(["question"])
    return dataset

def get_gsm8k_dataset(split, rank, world_size):
    dataset = load_dataset(path="openai/gsm8k", name="main", split=split)
    dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
    return process_gsm8k_rl_dataset(dataset)
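To make the per-rank split concrete, the following pure-Python sketch computes contiguous shards. Our understanding is that split_dataset_by_node partitions a map-style Dataset into contiguous shards in a similar way, with earlier ranks receiving one extra row when the size is not divisible; treat this as an approximation rather than the library's code. `contiguous_shard` and `to_messages` are hypothetical names, and `to_messages` mirrors the `process` function above.

```python
def contiguous_shard(n_items: int, rank: int, world_size: int) -> range:
    """Index range owned by `rank` when n_items are split into contiguous shards.

    Earlier ranks receive one extra item when n_items % world_size != 0.
    """
    base, rem = divmod(n_items, world_size)
    start = rank * base + min(rank, rem)
    end = start + base + (1 if rank < rem else 0)
    return range(start, end)


def to_messages(sample: dict) -> dict:
    # Same transformation as `process` above: wrap the question as a user turn
    # and drop the original "question" column.
    out = {k: v for k, v in sample.items() if k != "question"}
    out["messages"] = [{"role": "user", "content": sample["question"]}]
    return out


rows = [{"question": f"q{i}", "answer": f"a{i} #### {i}"} for i in range(10)]
for rank in range(4):
    print(rank, list(contiguous_shard(len(rows), rank, world_size=4)))
# 0 [0, 1, 2]
# 1 [3, 4, 5]
# 2 [6, 7]
# 3 [8, 9]
print(to_messages(rows[0]))
# {'answer': 'a0 #### 0', 'messages': [{'role': 'user', 'content': 'q0'}]}
```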

We then prepare training and evaluation dataloaders with torchdata.StatefulDataLoader:

train_dataloader = torchdata.StatefulDataLoader(
    get_gsm8k_dataset("train", rank, world_size),
    batch_size=config.train_dataset.batch_size // world_size,
    shuffle=config.train_dataset.shuffle,
    num_workers=config.train_dataset.num_workers,
    collate_fn=lambda x: x,
    drop_last=config.train_dataset.drop_last,
)
valid_dataloader = ...
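The dataloader arithmetic above can be sketched without torchdata: each rank loads `batch_size // world_size` prompts per step, and the identity `collate_fn=lambda x: x` keeps every batch as a plain list of dicts so the rollout workflow receives raw samples. The helper `per_rank_batches` is hypothetical, and its divisibility check is our own addition: with plain integer division, a global batch size not divisible by the world size would silently shrink the effective global batch.

```python
def per_rank_batches(samples: list, global_batch_size: int, world_size: int,
                     drop_last: bool = True) -> list[list]:
    """Group one rank's samples into lists, like an identity-collate dataloader."""
    if global_batch_size % world_size != 0:
        raise ValueError("global batch size should be divisible by world size")
    bs = global_batch_size // world_size  # per-rank batch size
    batches = [samples[i:i + bs] for i in range(0, len(samples), bs)]
    # drop_last discards a trailing batch that is smaller than bs.
    if drop_last and batches and len(batches[-1]) < bs:
        batches.pop()
    return batches


# 256 prompts per global step over 8 ranks -> 32 prompts per rank per step.
print([len(b) for b in per_rank_batches(list(range(70)), 256, 8)])  # [32, 32]
```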

If you wish to use your own Hugging Face datasets or datasets on local storage, please refer to Customization: Dataset for further details.

Rollout

Inference Engine: RemoteSGLangEngine

In AReaL-lite, generation tasks are offloaded to remote inference servers, which run on GPUs separate from those used for training. RemoteSGLangEngine acts as a client that interacts with these servers. It runs in an SPMD manner on every training process without occupying any GPUs.

RemoteSGLangEngine provides two core APIs for accessing the remote servers: agenerate and update_weights_async. Note that in asynchronous RL experiments in AReaL-lite, an inference-side weight update can happen in the middle of generating the response to a prompt. As a result, a single output sequence may be generated by multiple model versions. Let us look at the code of agenerate and update_weights_async to understand how this works.

In update_weights_async, the engine first sends pause_generation requests to all inference servers, notifying them that a weight update is about to happen. Upon receiving pause_generation, the inference servers immediately stop generating and respond with the tokens generated so far. Then, the engine sends update_weights_from_distributed (for NCCL updates) or update_weights_from_disk (for disk updates). After the update is finished, the engine sends continue_generation to the inference servers, telling them to resume.

class RemoteSGLangEngine:
    ...
    def update_weights_async(self, meta: WeightUpdateMeta):
        # `update_weights_async` is fully asynchronous:
        # it submits a task to a ProcessPoolExecutor and returns a future.
        for addr in self.addresses:
            # Pause generation on every server before weights change.
            requests.post(f"http://{addr}/pause_generation")
        if meta.type == "nccl":
            future = self.executor.submit(
                # a function that sends the `update_weights_from_distributed` request
                update_weights_from_distributed,
            )
        elif meta.type == "disk":
            ...

        def callback(future):
            # Resume generation once the weight update has finished.
            for addr in self.addresses:
                requests.post(f"http://{addr}/continue_generation")

        future.add_done_callback(callback)
        return future
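The pause/update/resume ordering can be checked with a self-contained simulation. Here in-memory `FakeServer` objects replace the HTTP endpoints and a ThreadPoolExecutor replaces the ProcessPoolExecutor; `FakeServer` and `simulate_update_weights_async` are hypothetical names. The sketch shows the guarantee the design relies on: every server is paused before its weights change, and generation resumes only from the future's done-callback.

```python
from concurrent.futures import Future, ThreadPoolExecutor


class FakeServer:
    """In-memory stand-in for an inference server (hypothetical)."""

    def __init__(self, name: str):
        self.name = name
        self.paused = False
        self.version = 0
        self.log: list[str] = []

    def pause_generation(self):
        self.paused = True
        self.log.append("pause")

    def update_weights(self, version: int):
        assert self.paused, "weights must only change while generation is paused"
        self.version = version
        self.log.append(f"update->{version}")

    def continue_generation(self):
        self.paused = False
        self.log.append("continue")


def simulate_update_weights_async(servers, executor, version: int) -> Future:
    # 1. Pause every server synchronously, before any weights change.
    for s in servers:
        s.pause_generation()
    # 2. Run the (potentially slow) update in the background.
    future = executor.submit(lambda: [s.update_weights(version) for s in servers])

    # 3. Resume generation once the update future is done.
    def callback(_fut: Future):
        for s in servers:
            s.continue_generation()

    future.add_done_callback(callback)
    return future


servers = [FakeServer("a"), FakeServer("b")]
with ThreadPoolExecutor(max_workers=1) as pool:
    simulate_update_weights_async(servers, pool, version=1).result()
# Leaving the `with` block joins the worker, so the callback has already run.
print(servers[0].log)  # ['pause', 'update->1', 'continue']
```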

agenerate takes an LLMRequest with input_ids of a single prompt and generation hyperparameters, and returns the final generation result, an LLMResponse with output_tokens and other outputs. Since the generation could be interrupted, agenerate iteratively prepares payload, sends requests and receives responses until the generation finishes.

class RemoteSGLangEngine:
    ...
    async def agenerate(self, req: LLMRequest):
        payload = ... # prepare payload for request
        max_new_tokens = req.gconfig.max_new_tokens
        # If the request comes from the same workflow, choose the same server
        # to allow KV cache reuse. Otherwise, choose a server in a round-robin
        # manner.
        server_addr = self.choose_server(req)
        stop_reason = None
        # other outputs are omitted for simplicity
        output_tokens = []
        while (stop_reason != "stop" and len(output_tokens) < max_new_tokens):
            # The previous request was interrupted (e.g., by a weight update);
            # wait briefly to avoid contention
            if stop_reason is not None:
                await asyncio.sleep(0.5)
            # send request to the remote server
            result = await arequest_with_retry(
                addr=server_addr,
                endpoint="/generate",
                payload=payload,
                method="POST"
            )
            output_tokens.extend(result["output_ids"])
            # "stop" means generation finished normally; other values, such as
            # an abort caused by pause_generation, trigger another request
            stop_reason = result["meta_info"]["finish_reason"]["type"]
            # prepare payload for the next request
            payload["input_ids"] += result["output_ids"]
            payload["sampling_params"]["max_new_tokens"] -= len(result["output_ids"])
        return LLMResponse(
            input_tokens=req.input_ids,
            output_tokens=output_tokens,
            ...
        )
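The resume-after-interruption loop can be exercised offline. In this sketch, a hypothetical `FakeGenServer` emits at most `chunk` tokens per call and then reports an abort, and it bumps its model version after every call to mimic a weight update in between. Each emitted token records the version that produced it, which makes visible the point above that one output sequence may come from several model versions. `agenerate_sketch` mirrors the loop structure of agenerate but is not the AReaL code.

```python
import asyncio


class FakeGenServer:
    """Hypothetical server that is 'interrupted' after every `chunk` tokens."""

    def __init__(self, stop_at_len: int, chunk: int):
        self.stop_at_len = stop_at_len  # sequence length at which EOS is emitted
        self.chunk = chunk              # max tokens per call before an "abort"
        self.version = 0

    async def generate(self, input_ids: list[int], max_new_tokens: int) -> dict:
        await asyncio.sleep(0)  # yield control like a real network call
        remaining = self.stop_at_len - len(input_ids)
        n = min(self.chunk, remaining, max_new_tokens)
        out = [self.version] * n  # each token records the model version used
        if n == remaining:
            reason = "stop"
        elif n == max_new_tokens:
            reason = "length"
        else:
            reason = "abort"  # interrupted, e.g. by pause_generation
        self.version += 1  # simulate a weight update between calls
        return {"output_ids": out, "finish_reason": reason}


async def agenerate_sketch(server, input_ids: list[int], max_new_tokens: int):
    payload_ids, budget = list(input_ids), max_new_tokens
    output, reason = [], None
    while reason != "stop" and len(output) < max_new_tokens:
        if reason is not None:
            await asyncio.sleep(0.01)  # back off after an interruption
        result = await server.generate(payload_ids, budget)
        output.extend(result["output_ids"])
        reason = result["finish_reason"]
        payload_ids += result["output_ids"]  # resend prompt + partial output
        budget -= len(result["output_ids"])  # shrink the remaining token budget
    return output, reason


tokens, reason = asyncio.run(agenerate_sketch(FakeGenServer(stop_at_len=9, chunk=2), [0, 0, 0], 16))
print(tokens, reason)  # [0, 0, 1, 1, 2, 2] stop
```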

The InferenceEngine class is designed to be extensible, supporting not just SGLang but also other backends like vLLM. While different inference engines may be used, the rollout management logic remains consistent. This common functionality is abstracted into the WorkflowExecutor, which will be introduced in the following section.

RLVRWorkflow and WorkflowExecutor

The rollout data lifecycle is controlled by an RLVRWorkflow, which defines how data progresses from prompts to complete rollout data containing all fields required for training. Our example uses a single-turn RLVR workflow with a math reward function. The core logic of the workflow is implemented in the async method arun_episode, which takes a prompt, generates answers with RemoteSGLangEngine, computes rewards, and populates additional fields to produce the final training data.

class RLVRWorkflow(RolloutWorkflow):
    def __init__(
        self, reward_fn, gconfig, tokenizer, ...
    ):
        self.reward_fn = reward_fn
        self.gconfig = gconfig
        self.tokenizer = tokenizer

    async def arun_episode(self, engine, data):
        # rollout data with inference engine
        input_ids = self.tokenizer.apply_chat_template(data["messages"], ...)
        req = LLMRequest(rid=..., input_ids=input_ids, gconfig=self.gconfig.new(n_samples=1))
        resps = await asyncio.gather(
            *[engine.agenerate(req) for _ in range(self.gconfig.n_samples)]
        )
        # post process rollout responses
        results = []
        for resp in resps:
            reward = self.reward_fn(...)
            ... # other required fields for training
            res = dict(
                input_ids=...,
                rewards=...,
                ... # other required fields for training
            )
            results.append(res)
        # return `self.gconfig.n_samples` padded samples for the prompt `data["messages"]`
        return concat_padded_tensors(results)

def gsm8k_reward_fn(completions, answer):
    ...

tokenizer = load_hf_tokenizer(config.tokenizer_path)
workflow = RLVRWorkflow(
    reward_fn=gsm8k_reward_fn,
    gconfig=config.gconfig,
    tokenizer=tokenizer,
    ...
)
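The body of gsm8k_reward_fn is elided above. Purely as an illustration, here is a hypothetical reward rule: GSM8K reference solutions end with "#### <number>", so the sketch compares the last number in the completion with the number after the last "####". The names and signature below are our own and differ from gsm8k_reward_fn; the AReaL implementation may parse answers differently.

```python
import re
from typing import Optional

_NUMBER = re.compile(r"-?\d[\d,]*(?:\.\d+)?")


def extract_last_number(text: str) -> Optional[float]:
    """Return the last number in `text` (thousands separators removed), or None."""
    matches = _NUMBER.findall(text)
    if not matches:
        return None
    return float(matches[-1].replace(",", ""))


def gsm8k_reward_sketch(completion: str, answer: str) -> float:
    """1.0 if the completion's last number equals the GSM8K reference, else 0.0."""
    reference = extract_last_number(answer.split("####")[-1])
    predicted = extract_last_number(completion)
    if reference is None or predicted is None:
        return 0.0
    return 1.0 if abs(predicted - reference) < 1e-6 else 0.0


print(gsm8k_reward_sketch("She earns $18 per day.", "... #### 18"))  # 1.0
```

A binary, verifiable reward like this is what makes the workflow "RLVR": no learned reward model is involved.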

WorkflowExecutor manages data streaming through rollout workflows and collates completed rollout data into batched training samples. On initialization, it launches a rollout thread that runs rollout workflows as asyncio tasks. The following code shows a simplified version of the rollout thread, which iteratively:

  • Checks the available capacity. The capacity limits the number of concurrently running rollout workflows, which bounds both concurrency and data off-policyness (the gap between the model version used for generation and the model version being updated by the trainer).
  • If capacity remains and rollout is not paused for a weight update, obtains data from input_queue and creates asyncio tasks to run the workflows.
  • Waits for rollout workflows to finish.
  • Gathers data from finished workflows and puts it into output_queue.

import asyncio
import traceback

class WorkflowExecutor:
    ...
    async def _rollout_thread_async(self):
        rid = 0
        # running rollout tasks, keyed by their name (the stringified rid)
        rollout_tasks = {}
        try:
            while not self.exiting.is_set():
                # Check capacity
                capacity = self.get_capacity()
                # Create rollout tasks with data obtained from input_queue
                while (
                    capacity > 0
                    and not self.paused.is_set()
                    and self.input_queue.qsize() > 0
                ):
                    data, workflow = self.input_queue.get_nowait()
                    task = asyncio.create_task(
                        workflow.arun_episode(self, data), name=str(rid)
                    )
                    rollout_tasks[str(rid)] = task
                    self.rollout_stat.submitted += 1
                    self.rollout_stat.running += 1
                    capacity -= 1
                    rid += 1
                # Wait for rollout completion
                tasks = list(rollout_tasks.values())
                completed_tasks = []
                if tasks:
                    completed_tasks, _ = await asyncio.wait(
                        tasks,
                        timeout=ROLLOUT_POLL_WAIT_TIME,
                        return_when=asyncio.FIRST_COMPLETED,
                    )
                # Collect done results, put the results into output queue
                for task in completed_tasks:
                    traj = await task
                    task_rid = task.get_name()
                    rollout_tasks.pop(task_rid)
                    self.rollout_stat.accepted += 1
                    self.output_queue.put_nowait(traj)
                    self.rollout_stat.running -= 1
                await asyncio.sleep(1)
        except Exception:
            # Report errors from the rollout thread instead of dying silently
            traceback.print_exc()
    ...
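The capacity check can be sketched as the minimum of a concurrency bound and a staleness bound. This is our assumption about one way to compute it, not the get_capacity implementation; the parameter names `max_concurrent_rollouts`, `max_head_offpolicyness`, and `batch_size` are illustrative. Assuming samples are consumed roughly in submission order, capping submitted samples at `(max_head_offpolicyness + version + 1) * batch_size` keeps each sample at most about `max_head_offpolicyness` versions behind the model that trains on it; a value of 0 behaves like synchronous RL.

```python
def rollout_capacity(
    running: int,
    accepted: int,
    version: int,
    batch_size: int,
    max_concurrent_rollouts: int,
    max_head_offpolicyness: int,
) -> int:
    """Number of new rollout tasks that may start now (hypothetical sketch)."""
    # Concurrency bound: never more than max_concurrent_rollouts tasks in flight.
    concurrency_left = max_concurrent_rollouts - running
    # Staleness bound: after `version` training steps, allow at most
    # max_head_offpolicyness + 1 batches beyond those already trained on.
    submitted = accepted + running
    staleness_left = (max_head_offpolicyness + version + 1) * batch_size - submitted
    return min(concurrency_left, staleness_left)


# Fully synchronous setting (off-policyness 0): one batch of 4 at a time.
print(rollout_capacity(0, 0, 0, 4, 100, 0))  # 4
print(rollout_capacity(0, 4, 0, 4, 100, 0))  # 0, must wait for version 1
```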

With this rollout thread running, the training script (the main thread) submits prompts into input_queue and collates rollout data from output_queue into training batches with prepare_batch (for asynchronous RL) or rollout_batch (for synchronous RL). The following code shows the implementation of prepare_batch:

def prepare_batch(
    self,
    dataloader: StatefulDataLoader,
    workflow: "RolloutWorkflow",
):
    if not hasattr(self, "data_generator"):
        self.data_generator = itertools.cycle(dataloader)
    assert dataloader.batch_size is not None
    while True:
        # Submit at least two batches to allow maximum overlap
        if (
            self.get_capacity() + dataloader.batch_size > 0
            and self.input_queue.qsize() + dataloader.batch_size
            < self.input_queue.maxsize
        ):
            data = next(self.data_generator)
            for item in data:
                # submit data into input_queue
                self.submit(item, workflow=workflow)
        try:
            # wait for dataloader.batch_size data from output_queue
            return self.wait(dataloader.batch_size, timeout=1)
        except TimeoutError:
            pass
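Because prepare_batch calls self.wait with timeout=1 inside a loop, results gathered during a call that times out must not be discarded. The sketch below shows one way to get that behavior, assuming a standard queue.Queue as the output queue; `BatchCollector` is a hypothetical name, and keeping partial results in a buffer across calls is our assumption about the design, not confirmed from the source.

```python
import queue
import time


class BatchCollector:
    """Hypothetical `wait(count, timeout)` that keeps partial results across calls."""

    def __init__(self, output_queue: queue.Queue):
        self.output_queue = output_queue
        self.buffer: list = []

    def wait(self, count: int, timeout: float) -> list:
        deadline = time.monotonic() + timeout
        while len(self.buffer) < count:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                # Items already moved into self.buffer stay there for the next call.
                raise TimeoutError(f"only {len(self.buffer)}/{count} results ready")
            try:
                self.buffer.append(self.output_queue.get(timeout=remaining))
            except queue.Empty:
                pass
        batch, self.buffer = self.buffer[:count], self.buffer[count:]
        return batch
```

With this design, the caller's retry loop makes progress on every iteration even when a full batch is not yet available.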

RemoteSGLangEngine exposes rollout_batch and prepare_batch by delegating them to its workflow executor:

class RemoteSGLangEngine(InferenceEngine):
    ...
    def prepare_batch(self, *args, **kwargs):
        return self.workflow_executor.prepare_batch(*args, **kwargs)

The usage of RemoteSGLangEngine in the training script is simple:

rollout = RemoteSGLangEngine(config.inf_engine)
rollout.initialize()
eval_rollout = ...

data_generator = itertools.cycle(train_dataloader)
for global_step in range(max_steps):
    # rollout batched training data for current step
    if config.async_training:
        batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
    else:
        batch = rollout.rollout_batch(next(data_generator), workflow=workflow)

If you want to use rollout workflows with custom reward functions or agentic tool calling, see Customization: Rollout Workflows for more details.

Training

After obtaining the training batch, we use FSDPPPOActor to compute losses and update weights. Each train engine corresponds to one model, so we need an additional engine for the reference model. Note that torch.distributed process groups are lazily initialized with init_process_group when the first train engine is initialized. Initializing a train engine also loads the model weights from the paths specified in the configuration.

actor = FSDPPPOActor(config=config.actor)
actor.initialize(None, ft_spec)
ref = None
if config.actor.kl_ctl > 0 and config.ref is not None:
    ref = FSDPPPOActor(config=config.ref)
    ref.initialize(None, ft_spec)

FSDPPPOActor is a high-level engine with algorithm-specific APIs, such as compute_logp, compute_advantages, and ppo_update. It is powered by the lower-level train engine FSDPEngine, which uses PyTorch FSDP2 to provide basic model APIs such as train_batch and forward. The following code shows a GRPO training step:

logp = actor.compute_logp(batch)
batch["prox_logp"] = logp
if ref is not None:
    batch["ref_logp"] = ref.compute_logp(batch)
    log_gpu_stats("ref logp")
actor.compute_advantages(batch)
stats = actor.ppo_update(batch)
actor.step_lr_scheduler()
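The key GRPO-specific piece of compute_advantages is the group-relative baseline. The sketch below uses the standard GRPO formulation, normalizing each reward by the mean and standard deviation of the `n_samples` answers to the same prompt; it is not AReaL's implementation, which may differ (for example, in whether the standard deviation is used, which estimator is used, or how advantages are broadcast to tokens). This sketch uses the population standard deviation.

```python
from statistics import mean, pstdev


def grpo_advantages(rewards: list[float], group_size: int, eps: float = 1e-6) -> list[float]:
    """Group-relative advantages: (r - group mean) / (group std + eps).

    `rewards` holds `group_size` consecutive samples per prompt, i.e. the
    n_samples answers produced by one workflow episode.
    """
    if len(rewards) % group_size != 0:
        raise ValueError("rewards must contain whole groups")
    advantages = []
    for start in range(0, len(rewards), group_size):
        group = rewards[start:start + group_size]
        mu, sigma = mean(group), pstdev(group)
        advantages.extend((r - mu) / (sigma + eps) for r in group)
    return advantages


print([round(a, 3) for a in grpo_advantages([1.0, 0.0, 0.0, 1.0], group_size=4)])
# [1.0, -1.0, -1.0, 1.0]
```

Groups whose answers all receive the same reward get zero advantage, so they contribute no policy-gradient signal.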

If you want to customize your own training algorithm, see Customize algorithms for more details.

Transferring Weights to Inference Servers

After training, we transfer the updated model weights to the remote inference servers through cooperation between FSDPPPOActor and RemoteSGLangEngine. Weights can be transferred either through shared storage or via NCCL. In our example training script, we first prepare a WeightUpdateMeta for the NCCL backend on all training processes.

# NOTE: Weight update meta only requires address and free port of rank 0,
# but `WeightUpdateMeta.from_fsdp_nccl` has to be executed on all ranks
# due to `engine.get_param_specs()`.
# Therefore, we create weight update meta on all ranks, then broadcast the one on rank 0.
weight_update_meta = [
    WeightUpdateMeta.from_fsdp_nccl(
        AllocationMode.from_str(config.allocation_mode), actor
    )
]
dist.broadcast_object_list(weight_update_meta, src=0)
weight_update_meta = weight_update_meta[0]

If you wish to transfer model weights from shared storage, you can use:

weight_update_meta = WeightUpdateMeta.from_disk(config.saver)

After a training step is finished, we transfer new weights from actor engine to remote inference servers:

  1. The rollout engine must stop sending generation requests to the remote servers (rollout.pause()) before the weight update to avoid server-side congestion.
  2. Because the weight update must run on the trainer engine and on the remote inference servers at the same time, the training script first sends the update request to the inference servers asynchronously (on rank 0) and then immediately uploads weights from the trainer engine.

rollout.pause()
if dist.get_rank() == 0:
    future = rollout.update_weights_async(weight_update_meta)
actor.upload_weights(weight_update_meta)
if dist.get_rank() == 0:
    future.result()
dist.barrier(device_ids=[actor.device.index])
torch.cuda.synchronize()
rollout.resume()
actor.set_version(global_step + 1)
rollout.set_version(global_step + 1)
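Why the server request must be asynchronous can be shown with a toy model in which a two-party threading.Barrier stands in for the NCCL broadcast that both the trainer and the inference server must enter. This is an analogy under that assumption, not the actual transfer code, and `run_weight_update` is a hypothetical name. If rank 0 blocked on the server request before uploading, neither side would reach the collective; here the barrier timeout turns that hang into a detectable failure.

```python
import threading
from concurrent.futures import ThreadPoolExecutor


def run_weight_update(async_request: bool, timeout: float = 0.2) -> bool:
    """Return True if the toy 'collective' completes, False if it deadlocks."""
    collective = threading.Barrier(2, timeout=timeout)  # sender + receiver

    def server_receive_weights():
        collective.wait()  # server joins the broadcast when asked to update

    def trainer_upload_weights():
        collective.wait()  # trainer joins the broadcast to send weights

    with ThreadPoolExecutor(max_workers=1) as pool:
        try:
            if async_request:
                future = pool.submit(server_receive_weights)  # like update_weights_async
                trainer_upload_weights()                       # like actor.upload_weights
                future.result()
            else:
                # Blocking on the server first: it waits alone until the timeout.
                pool.submit(server_receive_weights).result()
                trainer_upload_weights()
            return True
        except threading.BrokenBarrierError:
            return False


print(run_weight_update(async_request=True))   # True
print(run_weight_update(async_request=False))  # False
```

A real NCCL collective has no such timeout by default, so the blocking order would hang the job rather than fail.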

Now a complete GRPO training step in AReaL-lite is done! The core logic of our example training script can be summarized as:

data_generator = itertools.cycle(train_dataloader)
for global_step in range(max_steps):
    if config.async_training:
        batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
    else:
        batch = rollout.rollout_batch(next(data_generator), workflow=workflow)

    logp = actor.compute_logp(batch)
    batch["prox_logp"] = logp
    if ref is not None:
        batch["ref_logp"] = ref.compute_logp(batch)
        log_gpu_stats("ref logp")
    actor.compute_advantages(batch)
    stats = actor.ppo_update(batch)
    actor.step_lr_scheduler()

    rollout.pause()
    if dist.get_rank() == 0:
        future = rollout.update_weights_async(weight_update_meta)
    actor.upload_weights(weight_update_meta)
    if dist.get_rank() == 0:
        future.result()
    rollout.resume()
    actor.set_version(global_step + 1)
    rollout.set_version(global_step + 1)

Utilities

AReaL-lite provides a range of utilities for observing and tuning your experiments.

Saver and Evaluator

Saver (arealite/utils/saver.py) and Evaluator (arealite/utils/evaluator.py) control how often the model in the train engine is saved and evaluated.

In our example, we call saver.save and evaluator.evaluate after every training step. These two methods automatically check whether it is time to save or evaluate the model according to the experiment configuration.
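The "is it time yet?" check can be sketched as a trigger that fires when any configured frequency is reached. The field names `freq_steps`, `freq_epochs`, and `freq_secs` and the class `FreqTrigger` are illustrative assumptions; consult arealite/utils/saver.py and the configuration classes for the real options.

```python
import time
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class FreqTrigger:
    """Hypothetical frequency check: any configured condition fires; None disables it."""
    freq_steps: Optional[int] = None
    freq_epochs: Optional[int] = None
    freq_secs: Optional[float] = None
    _last_time: float = field(default_factory=time.monotonic)

    def check(self, global_step: int, epoch: int, is_epoch_end: bool) -> bool:
        fire = False
        # Every `freq_steps` steps (steps are 0-indexed here).
        if self.freq_steps and (global_step + 1) % self.freq_steps == 0:
            fire = True
        # At the end of every `freq_epochs`-th epoch.
        if self.freq_epochs and is_epoch_end and (epoch + 1) % self.freq_epochs == 0:
            fire = True
        # After `freq_secs` of wall-clock time since the last trigger.
        if self.freq_secs is not None and time.monotonic() - self._last_time >= self.freq_secs:
            fire = True
        if fire:
            self._last_time = time.monotonic()
        return fire


trigger = FreqTrigger(freq_steps=3)
print([trigger.check(s, 0, False) for s in range(6)])
# [False, False, True, False, False, True]
```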

stats_tracker

stats_tracker (realhf/base/stats_tracker.py) gathers training statistics across parallel ranks and reduces them.

  1. Scalar-type statistics are recorded by stats_tracker.scalar(key=value) and will be averaged by the number of scalars with the same key when reduced.
  2. Tensor-type statistics require denominator and reduce_type to decide how to reduce statistics under the same key.
  • denominator is a bool tensor that masks the tensor elements: only elements where it is True are recorded.
  • reduce_type includes average, sum, min and max. By default, the average, min and max are all calculated.
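The masked reductions above can be sketched with plain lists standing in for tensors; `masked_reduce` and `cross_rank_average` are hypothetical helpers, not stats_tracker internals. The second function illustrates why a denominator mask is recorded rather than a per-rank mean: a correct cross-rank average divides the total sum of selected values by the total number of selected elements, which generally differs from the average of per-rank averages.

```python
def masked_reduce(values, mask, reduce_types=("avg", "min", "max")):
    """Reduce `values` over elements where `mask` is True (hypothetical sketch)."""
    selected = [v for v, m in zip(values, mask) if m]
    if not selected:
        return {}
    ops = {
        "avg": lambda xs: sum(xs) / len(xs),
        "sum": sum,
        "min": min,
        "max": max,
    }
    return {name: ops[name](selected) for name in reduce_types}


def cross_rank_average(shards):
    """Average over all ranks' selected elements: sum(values) / sum(mask)."""
    total = sum(v for values, mask in shards for v, m in zip(values, mask) if m)
    count = sum(1 for _, mask in shards for m in mask if m)
    return total / count if count else float("nan")


seqlens = [10.0, 20.0, 30.0, 40.0]
correct = [True, False, True, False]  # e.g. reward_score > 0
print(masked_reduce(seqlens, correct))  # {'avg': 20.0, 'min': 10.0, 'max': 30.0}
# Rank 0 holds [10], rank 1 holds [20, 30]: the true mean is 20.0,
# while averaging the per-rank means would give 17.5.
print(cross_rank_average([([10.0], [True]), ([20.0, 30.0], [True, True])]))  # 20.0
```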

For example, if we want to record the length of sequences with correct and incorrect answers in a training batch:

seqlens = ... # tensor of shape [#seqs,]
reward_score = ... # tensor of shape [#seqs,]

result_denominators = {
    "correct_n_seqs": (reward_score > 0).bool(),
    "incorrect_n_seqs": (reward_score <= 0).bool(),
}
# register the denominator
stats_tracker.denominator(**result_denominators)
# record the correct and incorrect sequence length
stats_tracker.stat(
    correct_seq_len=seqlens.float(), denominator="correct_n_seqs"
)
stats_tracker.stat(
    incorrect_seq_len=seqlens.float(), denominator="incorrect_n_seqs"
)

stats_tracker also offers a timing context, record_timing, that records the time cost of a code block as a scalar, and a scope context, scope, that manages the key prefixes of statistics.

with stats_tracker.record_timing("train_step"):
    # training step
    ...

with stats_tracker.scope("A"):
    stats_tracker.scalar(c=123) # key="A/c", value=123
    with stats_tracker.scope("B"):
        stats_tracker.scalar(c=234) # key="A/B/c", value=234
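To show how nested scopes can produce keys such as "A/B/c", here is a tiny stand-in built on contextlib.contextmanager. `MiniTracker` is hypothetical and far simpler than stats_tracker: it only handles scalars, averages scalars that share a key (as described above), and clears its state on export, which is a choice of this sketch.

```python
from contextlib import contextmanager


class MiniTracker:
    """Hypothetical stand-in showing how nested scopes could build keys."""

    def __init__(self):
        self._scopes: list[str] = []
        self.scalars: dict[str, list[float]] = {}

    @contextmanager
    def scope(self, name: str):
        self._scopes.append(name)
        try:
            yield
        finally:
            self._scopes.pop()  # restore the outer prefix even on errors

    def scalar(self, **kwargs: float):
        prefix = "/".join(self._scopes)
        for key, value in kwargs.items():
            full_key = f"{prefix}/{key}" if prefix else key
            self.scalars.setdefault(full_key, []).append(value)

    def export(self) -> dict[str, float]:
        # Scalars sharing a key are averaged; state is reset afterwards.
        out = {k: sum(v) / len(v) for k, v in self.scalars.items()}
        self.scalars.clear()
        return out


tracker = MiniTracker()
with tracker.scope("A"):
    tracker.scalar(c=123)
    with tracker.scope("B"):
        tracker.scalar(c=234)
print(tracker.export())  # {'A/c': 123.0, 'A/B/c': 234.0}
```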

After recording sufficient data, e.g. after a train_batch is finished, stats_tracker.export is called to aggregate all statistics and dump them into a dictionary.

stats = stats_tracker.export()

StatsLogger

StatsLogger (arealite/utils/stats_logger.py) logs gathered training statistics to recorders such as wandb and TensorBoard on rank 0. In our example script, after each training step, logger.commit(epoch, step, global_step, stats) is called to print all statistics exported from stats_tracker and log them to the recorders set in the configuration.

Next Steps