# Running GRPO on GSM8K Dataset

This guide introduces how AReaL-lite runs the GRPO algorithm on the GSM8K dataset, using the training script [examples/arealite/gsm8k_grpo.py](../../examples/arealite/gsm8k_grpo.py) and configuration file [examples/arealite/configs/gsm8k_grpo.yaml](../../examples/arealite/configs/gsm8k_grpo.yaml).

## How AReaL-lite Works

The following figure illustrates the launch process and one asynchronous training step of the GRPO algorithm on the GSM8K dataset in AReaL-lite. Compared with the old AReaL implementation, AReaL-lite runs inference servers and an SPMD training script instead of a collection of specialized workers. In a training step, AReaL-lite:

1. Submits prompts from the dataset to `RemoteSGLangEngine`, which runs `RLVRWorkflow` in a streaming manner.
1. Completes `RLVRWorkflow` by interacting with remote `SGLangServer` instances to generate sequences and by computing rewards with the reward function.
1. Once there are enough outputs from `RLVRWorkflow`, aggregates them into a data batch for the algorithm-specific training engine `FSDPPPOActor`.
1. Computes losses and updates weights in `FSDPPPOActor`.
1. Transfers the updated weights to the remote `SGLangServer` instances.

![arealite-gsm8k-example](gsm8k_grpo.png)

In the following sections, we walk you through the code to explain these concepts and show how each of these steps is done in detail.

## Launching the Experiment

As shown in the [quickstart guide](../tutorial/quickstart.md), experiments in AReaL-lite are launched using standalone launchers with the following commands:

```
# Local Launcher
python -m arealite.launcher.local <training script> --config <configuration file> <cli args>
# Ray Launcher
python -m arealite.launcher.ray <training script> --config <configuration file> <cli args>
# Slurm Launcher
python -m arealite.launcher.slurm <training script> --config <configuration file> <cli args>
```

In AReaL-lite:

- The **training script** is an SPMD Python script that serves as the experiment entry point.
- The launcher runs the training script with its distributed backend (`subprocess` for `LocalLauncher`, `ray.remote` for `RayLauncher`, `srun` for `SlurmLauncher`).
- The launcher also manages inference servers (currently only `SGLangServer` is supported). Their number and parallelization strategies (e.g. tensor parallelism) are determined by the option [allocation_mode](../../arealite/api/cli_args.py#L797).
- For distributed launchers (`RayLauncher` and `SlurmLauncher`), inference servers run with a wrapper, [arealite/launcher/sglang_server.py](../../arealite/launcher/sglang_server.py), that handles addresses and ports in distributed settings.
- After the `SGLangServer` instances are started, the launcher collects their addresses and ports and sets the `AREAL_LLM_SERVER_ADDRS` environment variable so that training scripts can access these inference servers, as sketched below.
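
For illustration, here is a minimal sketch of how a training process could inspect that variable. The comma-separated `host:port` format is an assumption for this example, not a documented contract:

```python
import os

# Hypothetical sketch: parse the inference server addresses set by the launcher,
# assuming a comma-separated "host:port" list such as "10.0.0.1:30000,10.0.0.2:30000".
addresses = os.environ["AREAL_LLM_SERVER_ADDRS"].split(",")
print(f"Discovered {len(addresses)} inference servers: {addresses}")
```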

The **configuration file** is a YAML file that sets the options defined in [arealite/api/cli_args.py](../../arealite/api/cli_args.py). It can be overridden via CLI arguments such as `actor.path=Qwen/Qwen3-1.7B` and `+sglang.attention_backend=triton`.
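
For example, a concrete local launch of this guide's experiment, combining the script, the configuration file, and the overrides above, looks like:

```
python -m arealite.launcher.local examples/arealite/gsm8k_grpo.py \
    --config examples/arealite/configs/gsm8k_grpo.yaml \
    actor.path=Qwen/Qwen3-1.7B \
    +sglang.attention_backend=triton
```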

The training script parses the configuration file, together with any CLI overrides, into the config class defined in [arealite/api/cli_args.py](../../arealite/api/cli_args.py):

```python
config, _ = load_expr_config(args, GRPOConfig)
config: GRPOConfig
```

## Loading and Preprocessing Dataset

We use the `datasets` and `torchdata` packages to load and preprocess the dataset into our dataloader. First, we download `openai/gsm8k` from Hugging Face, split it by data parallel rank, and map it to our desired format:

```python
from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node


def process_gsm8k_rl_dataset(dataset: Dataset):
    def process(sample):
        messages = [{"role": "user", "content": sample["question"]}]
        return {"messages": messages}
    dataset = dataset.map(process).remove_columns(["question"])
    return dataset

def get_gsm8k_dataset(split, rank, world_size):
    dataset = load_dataset(path="openai/gsm8k", name="main", split=split)
    dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
    return process_gsm8k_rl_dataset(dataset)
```

We then prepare training and evaluation dataloaders with `torchdata.StatefulDataLoader`:

```python
train_dataloader = torchdata.StatefulDataLoader(
    get_gsm8k_dataset("train", rank, world_size),
    batch_size=config.train_dataset.batch_size // world_size,
    shuffle=config.train_dataset.shuffle,
    num_workers=config.train_dataset.num_workers,
    collate_fn=lambda x: x,
    drop_last=config.train_dataset.drop_last,
)
valid_dataloader = ...
```
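
Since `collate_fn=lambda x: x`, each batch is simply a list of per-sample dicts rather than a collated tensor batch. As a quick sanity check (the `answer` field comes from the GSM8K schema and is kept by the mapping above):

```python
batch = next(iter(train_dataloader))
print(len(batch))                       # batch_size // world_size
print(sorted(batch[0].keys()))          # e.g. ['answer', 'messages']
print(batch[0]["messages"][0]["role"])  # "user"
```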

If you wish to use your own Hugging Face datasets or datasets on your local storage, please refer to [Customization: Dataset](../customization/dataset.md) for further details.

## Rollout
### Inference Engine: `RemoteSGLangEngine`

In AReaL-lite, generation tasks are offloaded to remote inference servers, which run on GPUs separate from those used for training. The `RemoteSGLangEngine` acts as a client that interacts with these servers. It runs in an SPMD manner on every training process, without occupying any GPUs itself.

`RemoteSGLangEngine` provides two core APIs that access the remote servers: `agenerate` and `update_weights_async`. It is worth mentioning that, in asynchronous RL experiments in AReaL-lite, an inference-side weight update can happen **in the middle of** generating a response to a prompt. In other words, one output sequence can be produced by multiple model versions. Let us look at the code of `agenerate` and `update_weights_async` for a better understanding.

In `update_weights_async`, the engine first sends `pause_generation` requests to all inference servers, notifying them that a weight update is about to happen. Upon receiving `pause_generation`, an inference server immediately stops generating and responds with the tokens generated so far. The engine then sends `update_weights_from_distributed` (for NCCL update) or `update_weights_from_disk` (for disk update). After the update is finished, the engine sends `continue_generation` to the inference servers, telling them to start working again.

```python
class RemoteSGLangEngine:
    ...
    def update_weights_async(self, meta: WeightUpdateMeta):
        # `update_weights_async` is completely async.
        # It submits a task to a ProcessPoolExecutor and returns a future.
        for addr in self.addresses:
            res = requests.post(f"http://{addr}/pause_generation")
        if meta.type == "nccl":
            future = self.executor.submit(
                # a function that sends the `update_weights_from_distributed` request
                update_weights_from_distributed,
            )
        elif meta.type == "disk":
            ...

        def callback(future):
            for addr in self.addresses:
                requests.post(f"http://{addr}/continue_generation")

        future.add_done_callback(callback)
        return future
```

`agenerate` takes an `LLMRequest` containing the `input_ids` of **a single prompt** and the generation hyperparameters, and returns the final generation result, an `LLMResponse` with `output_tokens` and other outputs. Since generation can be interrupted, `agenerate` iteratively prepares the payload, sends requests, and receives responses until the generation finishes.

```python
class RemoteSGLangEngine:
    ...
    async def agenerate(self, req: LLMRequest):
        payload = ...  # prepare payload for request
        max_new_tokens = ...  # taken from the generation config (elided)
        # If the request is from the same workflow, choose the old server
        # to allow KV cache reuse. Otherwise choose a server in a
        # round-robin manner.
        server_addr = self.choose_server(req)
        stop_reason = None
        # other outputs are omitted for simplicity
        output_tokens = []
        while stop_reason != "stop" and len(output_tokens) < max_new_tokens:
            # The request was interrupted; wait to avoid contention
            if stop_reason is not None:
                await asyncio.sleep(0.5)
            # send the request to the remote server
            result = await arequest_with_retry(
                addr=server_addr,
                endpoint="/generate",
                payload=payload,
                method="POST",
            )
            output_tokens.extend(result["output_ids"])
            # prepare the payload for the next request
            payload["input_ids"] += result["output_ids"]
            payload["sample_params"]["max_new_tokens"] -= len(result["output_ids"])
        return LLMResponse(
            input_tokens=req.input_ids,
            output_tokens=output_tokens,
            ...
        )
```

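For orientation, here is a hedged sketch of driving `agenerate` directly from a coroutine; in the example experiment, `RLVRWorkflow.arun_episode` (next section) issues these calls for you, and the elided fields follow the same conventions as the snippets above:

```python
# Hypothetical direct use of `agenerate`; `rid` and the generation
# config are elided exactly as in the surrounding snippets.
async def generate_one(engine, input_ids):
    req = LLMRequest(rid=..., input_ids=input_ids, gconfig=...)
    resp = await engine.agenerate(req)
    return resp.output_tokens
```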

The `InferenceEngine` class is designed to be extensible, supporting not just SGLang but also other backends such as vLLM. While different inference engines may be used, the rollout management logic remains consistent. This common functionality is abstracted into the `WorkflowExecutor`, which is introduced in the following section.

### `RLVRWorkflow` and `WorkflowExecutor`

The rollout data lifecycle is controlled by an `RLVRWorkflow`, which defines how data progresses from a prompt to complete rollout data containing all fields required for training. Our example shows a single-turn RLVR workflow with a math reward function. The core logic of the workflow is implemented in the async method `arun_episode`, which takes a prompt, generates answers with `RemoteSGLangEngine`, computes rewards, and populates additional fields to produce finalized training data.

```python
class RLVRWorkflow(RolloutWorkflow):
    def __init__(
        self, reward_fn, gconfig, tokenizer, ...
    ):
        self.reward_fn = reward_fn
        self.gconfig = gconfig
        self.tokenizer = tokenizer

    async def arun_episode(self, engine, data):
        # roll out data with the inference engine
        input_ids = self.tokenizer.apply_chat_template(data["messages"], ...)
        req = LLMRequest(rid=..., input_ids=input_ids, gconfig=self.gconfig.new(n_samples=1))
        resps = await asyncio.gather(
            *[engine.agenerate(req) for _ in range(self.gconfig.n_samples)]
        )
        # post-process rollout responses
        results = []
        for resp in resps:
            reward = self.reward_fn(...)
            ...  # compute other required fields for training
            res = dict(
                input_ids=...,
                rewards=...,
                # ... other required fields for training
            )
            results.append(res)
        # return `self.gconfig.n_samples` padded samples for prompt `data["messages"]`
        return concat_padded_tensors(results)

def gsm8k_reward_fn(completions, answer):
    ...

tokenizer = load_hf_tokenizer(config.tokenizer_path)
workflow = RLVRWorkflow(
    reward_fn=gsm8k_reward_fn,
    gconfig=config.gconfig,
    tokenizer=tokenizer,
    ...
)
```
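
The body of `gsm8k_reward_fn` is elided above. As a reference point only, here is a minimal sketch of such a verifier, assuming GSM8K's convention that the gold answer follows `####`; the actual function in the example script may differ:

```python
import re

def gsm8k_reward_fn_sketch(completions: str, answer: str) -> float:
    """Return 1.0 if the last number in the completion matches the gold answer."""
    gold = answer.split("####")[-1].strip().replace(",", "")
    numbers = re.findall(r"-?\d+(?:\.\d+)?", completions.replace(",", ""))
    if not numbers:
        return 0.0
    return 1.0 if numbers[-1] == gold else 0.0
```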

`WorkflowExecutor` is responsible for managing the data streaming through rollout workflows and for collating completed rollout data into batched training samples. Upon initialization, it launches a rollout thread that runs rollout workflows as `asyncio` tasks. The following code shows a simplified version of the rollout thread implementation, which iteratively:

- Checks the available capacity. The capacity limits the number of concurrent rollout workflows to control both concurrency and **data off-policyness** (the difference between the model version used for generation and the model version updated by the trainer).
- If there is capacity left and rollout is not paused for a weight update, continuously obtains data from `input_queue` and creates `asyncio` tasks to run the workflows.
- Waits for rollout workflows to finish.
- Gathers data from finished workflows and puts it into `output_queue`.

```python
class WorkflowExecutor:
    ...
    async def _rollout_thread_async(self):
        rid = 0
        rollout_tasks = {}  # maps request id -> asyncio.Task
        try:
            while not self.exiting.is_set():
                # Check capacity
                capacity = self.get_capacity()
                # Create rollout tasks with data obtained from input_queue
                while (
                    capacity > 0
                    and not self.paused.is_set()
                    and self.input_queue.qsize() > 0
                ):
                    data, workflow = self.input_queue.get_nowait()
                    task = asyncio.create_task(
                        workflow.arun_episode(self, data), name=str(rid)
                    )
                    rollout_tasks[str(rid)] = task
                    self.rollout_stat.submitted += 1
                    self.rollout_stat.running += 1
                    capacity -= 1
                    rid += 1
                # Wait for rollout completion
                tasks = list(rollout_tasks.values())
                completed_tasks = []
                if tasks:
                    completed_tasks, _ = await asyncio.wait(
                        tasks,
                        timeout=ROLLOUT_POLL_WAIT_TIME,
                        return_when=asyncio.FIRST_COMPLETED,
                    )
                # Collect the finished results and put them into the output queue
                for task in completed_tasks:
                    traj = await task
                    task_rid = task.get_name()
                    rollout_tasks.pop(task_rid)
                    self.rollout_stat.accepted += 1
                    self.output_queue.put_nowait(traj)
                    self.rollout_stat.running -= 1
                await asyncio.sleep(1)
        ...  # exception handling elided
```
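
The `get_capacity` call above is the throttle that bounds both concurrency and off-policyness. Below is a minimal sketch of how such a capacity function could combine the two caps; the attribute names (`max_concurrent_rollouts`, `max_offpolicyness`, `train_batch_size`) are assumptions for illustration rather than the exact AReaL-lite fields:

```python
def get_capacity(self) -> int:
    # Cap 1 (assumed field): bound the number of in-flight rollout workflows.
    concurrency_cap = self.max_concurrent_rollouts - self.rollout_stat.running
    # Cap 2 (assumed fields): stop submitting samples that would be generated
    # with a model version lagging the trainer by more than `max_offpolicyness`.
    version_cap = (
        (self.get_version() + self.max_offpolicyness + 1) * self.train_batch_size
        - self.rollout_stat.submitted
    )
    return min(concurrency_cap, version_cap)
```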

With this rollout thread running, the training script (the main thread) submits prompts into `input_queue` and collates rollout data from `output_queue` into training batches with `prepare_batch` (for asynchronous RL) or `rollout_batch` (for synchronous RL). The following code shows the implementation of `prepare_batch`:

```python
def prepare_batch(
    self,
    dataloader: StatefulDataLoader,
    workflow: "RolloutWorkflow",
):
    if not hasattr(self, "data_generator"):
        self.data_generator = itertools.cycle(dataloader)
    assert dataloader.batch_size is not None
    while True:
        # Submit at least two batches to allow maximum overlap
        if (
            self.get_capacity() + dataloader.batch_size > 0
            and self.input_queue.qsize() + dataloader.batch_size
            < self.input_queue.maxsize
        ):
            data = next(self.data_generator)
            for item in data:
                # submit data into input_queue
                self.submit(item, workflow=workflow)
        try:
            # wait for dataloader.batch_size samples from output_queue
            return self.wait(dataloader.batch_size, timeout=1)
        except TimeoutError:
            pass
```

`RemoteSGLangEngine` exposes `rollout_batch` and `prepare_batch` by forwarding the calls to its workflow executor:

```python
class RemoteSGLangEngine(InferenceEngine):
    ...
    def prepare_batch(self, *args, **kwargs):
        return self.workflow_executor.prepare_batch(*args, **kwargs)
```

The usage of `RemoteSGLangEngine` in the training script is simple:

```python
rollout = RemoteSGLangEngine(config.inf_engine)
rollout.initialize()
eval_rollout = ...

data_generator = itertools.cycle(train_dataloader)
for global_step in range(max_steps):
    # roll out a batch of training data for the current step
    if config.async_training:
        batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
    else:
        batch = rollout.rollout_batch(next(data_generator), workflow=workflow)
```

If you want to use rollout workflows with custom reward functions or agentic tool calling, see [Customization: Rollout Workflows](../customization/agent.md) for more details.

## Training

After obtaining the training batch, we use `FSDPPPOActor` to compute losses and update weights. Each train engine corresponds to one model, so we need an additional engine for the reference model. Note that `torch.distributed` process groups are lazily initialized with `init_process_group` when the first train engine is initialized. Initializing a train engine also loads model weights from the path specified in the configuration.

```python
actor = FSDPPPOActor(config=config.actor)
actor.initialize(None, ft_spec)
ref = None
if config.actor.kl_ctl > 0 and config.ref is not None:
    ref = FSDPPPOActor(config=config.ref)
    ref.initialize(None, ft_spec)
```

`FSDPPPOActor` is a high-level engine with algorithm-specific APIs, such as `compute_logp`, `compute_advantages`, and `ppo_update`. `FSDPPPOActor` is powered by the lower-level train engine `FSDPEngine`, which uses **PyTorch FSDP2** to provide basic model APIs such as `train_batch` and `forward`. The following code shows a GRPO training step:

```python
logp = actor.compute_logp(batch)
batch["prox_logp"] = logp
if ref is not None:
    batch["ref_logp"] = ref.compute_logp(batch)
    log_gpu_stats("ref logp")
actor.compute_advantages(batch)
stats = actor.ppo_update(batch)
actor.step_lr_scheduler()
```

If you want to customize your own training algorithm, see [Customize algorithms](../customization/algorithm.md) for more details.

## Transferring Weights to Inference Servers

After training, we transfer the updated model weights to the remote inference servers through cooperation between `FSDPPPOActor` and `RemoteSGLangEngine`. We provide options to transfer model weights through shared storage or via NCCL. In our example training script, we first prepare a `WeightUpdateMeta` for the NCCL backend on all training processes.

```python
# NOTE: The weight update meta only requires the address and a free port of rank 0,
# but `WeightUpdateMeta.from_fsdp_nccl` has to be executed on all ranks
# due to `engine.get_param_specs()`.
# Therefore, we create the weight update meta on all ranks, then broadcast the one from rank 0.
weight_update_meta = [
    WeightUpdateMeta.from_fsdp_nccl(
        AllocationMode.from_str(config.allocation_mode), actor
    )
]
dist.broadcast_object_list(weight_update_meta, src=0)
weight_update_meta = weight_update_meta[0]
```

If you wish to transfer model weights through shared storage instead, you can use:

```python
weight_update_meta = WeightUpdateMeta.from_disk(config.saver)
```

After a training step finishes, we transfer the new weights from the actor engine to the remote inference servers:

1. The rollout engine stops sending generation requests to the remote servers (`rollout.pause()`) before the weight update to avoid server-side congestion.
1. Since the weight update must be invoked on the trainer engine and the remote inference servers at the same time, the training script asynchronously sends requests to the remote inference servers and then immediately uploads weights from the trainer engine.

```python
rollout.pause()
if dist.get_rank() == 0:
    future = rollout.update_weights_async(weight_update_meta)
actor.upload_weights(weight_update_meta)
if dist.get_rank() == 0:
    future.result()
dist.barrier(device_ids=[actor.device.index])
torch.cuda.synchronize()
rollout.resume()
actor.set_version(global_step + 1)
rollout.set_version(global_step + 1)
```

Now a complete GRPO training step in AReaL-lite is done! The core logic of our example training script can be summarized as:

```python
data_generator = itertools.cycle(train_dataloader)
for global_step in range(max_steps):
    if config.async_training:
        batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
    else:
        batch = rollout.rollout_batch(next(data_generator), workflow=workflow)

    logp = actor.compute_logp(batch)
    batch["prox_logp"] = logp
    if ref is not None:
        batch["ref_logp"] = ref.compute_logp(batch)
        log_gpu_stats("ref logp")
    actor.compute_advantages(batch)
    stats = actor.ppo_update(batch)
    actor.step_lr_scheduler()

    rollout.pause()
    if dist.get_rank() == 0:
        future = rollout.update_weights_async(weight_update_meta)
    actor.upload_weights(weight_update_meta)
    if dist.get_rank() == 0:
        future.result()
    rollout.resume()
    actor.set_version(global_step + 1)
    rollout.set_version(global_step + 1)
```

## Utilities

AReaL-lite provides a range of utilities for the basic functionality required to observe and tune your experiments.

### `Saver` and `Evaluator`

`Saver` ([arealite/utils/saver.py](../../arealite/utils/saver.py)) and `Evaluator` ([arealite/utils/evaluator.py](../../arealite/utils/evaluator.py)) manage how frequently the model is saved and evaluated with the train engine.

In our example, we call `saver.save` and `evaluator.evaluate` after every training step. These two methods automatically check whether it is time to save or evaluate the model, according to the experiment configuration.
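
A hedged sketch of how these calls could appear at the end of each training step; the constructor and method signatures here are assumptions for illustration, so consult the utility modules linked above for the exact APIs:

```python
# Assumed constructors; both calls below are cheap no-ops unless the
# configured save/eval frequency for this step is reached.
saver = Saver(config.saver, ft_spec)
evaluator = Evaluator(config.evaluator, ft_spec)

saver.save(actor, epoch, step, global_step)
evaluator.evaluate(evaluate_fn, epoch, step, global_step)  # `evaluate_fn` is hypothetical
```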
### `stats_tracker`

`stats_tracker` ([realhf/base/stats_tracker.py](../../realhf/base/stats_tracker.py)) gathers training statistics across parallel ranks and reduces them.

1. **Scalar-type statistics** are recorded by `stats_tracker.scalar(key=value)` and are averaged over the number of scalars recorded under the same key when reduced.
1. **Tensor-type statistics** require a `denominator` and a `reduce_type` to decide how to reduce statistics under the same key.

- `denominator` is a bool tensor that masks out the elements of the tensor that we do not want to record.
- `reduce_type` can be average, sum, min, or max. By default, the average, min, and max are all calculated.

For example, to record the lengths of sequences with correct and incorrect answers in a training batch:

```python
seqlens = ...  # tensor of shape [#seqs,]
reward_score = ...  # tensor of shape [#seqs,]

result_denominators = {
    "correct_n_seqs": (reward_score > 0).bool(),
    "incorrect_n_seqs": (reward_score <= 0).bool(),
}
# register the denominators
stats_tracker.denominator(**result_denominators)
# record the correct and incorrect sequence lengths
stats_tracker.stat(
    correct_seq_len=seqlens.float(), denominator="correct_n_seqs"
)
stats_tracker.stat(
    incorrect_seq_len=seqlens.float(), denominator="incorrect_n_seqs"
)
```

`stats_tracker` also offers a timer context to record the time cost of a code block as a scalar, and a scope context to manage the keys of statistics.

```python
with stats_tracker.record_timing("train_step"):
    # training step
    ...

with stats_tracker.scope("A"):
    stats_tracker.scalar(c=123)  # key="A/c", value=123
    with stats_tracker.scope("B"):
        stats_tracker.scalar(c=234)  # key="A/B/c", value=234
```

After recording sufficient data, e.g. after a `train_batch` call finishes, `stats_tracker.export` is called to aggregate all statistics and dump them into a dictionary.

```python
stats = stats_tracker.export()
```

### `StatsLogger`

`StatsLogger` ([arealite/utils/stats_logger.py](../../arealite/utils/stats_logger.py)) logs the gathered training statistics to recorders such as `wandb` and `tensorboard` on rank 0. In our example script, after each training step, `logger.commit(epoch, step, global_step, stats)` is called to print all statistics from `stats_tracker` and log them to the recorders set in the configuration.
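
A hedged sketch of this end-of-step logging flow; the `StatsLogger` constructor arguments are assumptions for illustration:

```python
logger = StatsLogger(config.stats_logger, ft_spec)  # assumed constructor
for global_step in range(max_steps):
    ...  # training step as above
    stats = stats_tracker.export()
    logger.commit(epoch, step, global_step, stats)
```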
## Next Steps

- [Customize dataset](../customization/dataset.md)
- [Customize Agentic/RLVR rollout workflows](../customization/agent.md)
- [Customize algorithms](../customization/algorithm.md)