[WIP][feat] Initial support for VLMs, add Qwen2VL SFT test and Qwen2.5VL GRPO test (#188)

* vlm_sft_test

* vlm_sft_test

* .

* .

* Fix unresolved issue in SFTTrainer PR (#139)

* .

* .

* efficient loading

* format

* .

* .

* Fix unresolved issue in SFTTrainer PR (#139)

* .

* .

* efficient loading

* format

* .

* .

* image_process0701

* image_process0701

* image_process0701_2

* image_process0701_2

* image_process0701_3

* image_process0701_3

* .

* .

* .

* .

* .

* .

* imageprocess0702

* imageprocess0702

* image_process0702_2

* image_process0702_2

* image_process0702_3

* image_process0702_3

* image_process0702_4

* image_process0702_4

* image_process0702_5

* image_process0702_5

* image_process0703_1

* image_process0703_1

* 0703_2

* 0703_2

* 0703_3

* 0703_3

* 0703_4

* 0703_4

* 0703_4

* 0703_4

* 0703_5

* 0703_5

* 0703_6

* 0703_6

* 0703_7

* 0703_7

* 0703_8

* 0703_8

* 0703_9

* 0703_9

* 0703_11

* 0703_11

* 0703_12

* 0703_12

* 0703_13

* 0703_13

* 0703_14

* 0703_14

* 0703_15

* 0703_15

* 0703_16

* 0703_16

* 0703-17

* 0703-17

* 0703_18

* 0703_18

* 0703_18

* 0703_18

* 0703_19

* 0703_19

* 0704_1

* 0704_1

* 0704_2

* 0704_2

* 0704_3

* 0704_3

* .

* .

* 0707_1

* 0707_1

* 0707_2

* 0707_2

* 0703_3

* 0703_3

* r

* p

* fix

* fix

* refactor

* 0707_6

* 0707_7

* refactor1

* 0707_undone

* 0708_1

* 0708_2

* 0708_3

* 0708_7

* 0708_4

* 0709_1

* 0709_2

* 0709_3

* 0709_4

* 0709_5

* 0709_

* 0709_6

* 0709_7

* 0709_7

* 0709_8

* 0709_9

* 0710_1

* 0710_2

* 0710_2

* 0710_3

* 0710_3

* 0710_3

* 0710_5

* 0710_4

* merge_2

* merge_3

* 0711_1

* 0711_2

* 0711_3

* 0711_4

* 0711_6

* 0711_7

* 0711_8

* 0711_8

* 0711_9

* 0711_10

* 0711-11

* PullRequest: 353 [Lite] Add gradient checkpointing to FSDPEngine

Merge branch mzy/add-gradient-ckpt of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/353

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* add gradient checkpointing

* PullRequest: 354 [lite] GRPO pre-commit: minor changes in FSDP engine

Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/354

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .
* .
* .

* 0714_1

* 0714_2

* 0714_3

* 0714_3

* 0714_5

* PullRequest: 355 [Lite] GRPO pre-commit 2: Refactor RemoteSGLangEngine thread and SGLang configuration

Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/355?tab=commit

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .
* .
* .
* .
* .
* fix
* .

* PullRequest: 357 [lite] GRPO pre-commit 3: Fix typos and experiment utilities

Merge branch fw/lite-fix2 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/357?tab=comment

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .
* .
* .
* .
* fix destroy process group

* PullRequest: 358 [lite] Support GRPO training locally with the GSM8k dataset

Merge branch fw/lite-fix3 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/358

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .
* .
* .
* fix loss mask
* fix
* .

* 0715_1

* 0715_2

* 0715_2

* 0716_1

* 0716_2

* PullRequest: 368 [lite] Refactor train engine after merging contributions from GitHub

Merge branch fw/lite-train-engine of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/368

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .

* PullRequest: 371 [lite] [fix] fix misc bugs in GRPO implementation

Merge branch fw/lite-fix0716 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/371

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .

* 0716_3

* 0716_4

* 0716_4

* 0716_5

* 0717_1

* 0717_3

* 0717_3

* 0717_4

* 0717_5

* 0717_6

* 0717_6

* 0717_6

* 0718_2

* 0718_4

* 0718_5

* PullRequest: 370 [lite] Add Slurm Launcher and Ray Launcher

Merge branch mzy/lite/launcher of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/370

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* .
* .
* .
* fix

* merge_0721

* 0721_1

* PullRequest: 392 [lite] Fix several bugs regarding RL learning and add an example to reproduce boba-math results.

Merge branch fw/lite-boba of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/392

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* support fsdp engine and sglang remote engine
* minor fix
* .
* refactor trainer
* add close
* rm mb_spec
* .
* fix
* .
* qwen2 grpo works
* fix
* fix
* async works
* fix
* slurm launcher not tested
* fix arg parse
* .
* sglang server wrapper
* .
* .
* slurm run
* ready for boba
* debug
* 32k run
* .
* .
* fix
* .
* .
* .
* .
* .
* fix
* .
* fix
* .
* .
* .
* .
* fix
* .
* .
* .
* .
* .
* .
* .
* refactor train engine
* refactor train engine
* .
* fix update weight error
* .
* .
* match train
* format
* .
* fix
* seems to work
* .
* .
* .
* .

* 0721_2

* 0721_3

* 0721_4

* .

* 0721_formal

* 0721_formal

* 0721_merge4

* 0721_merge5

* 0721_6

* 0721_merge6

* 0721_merge7

* 0721_8

* 0722_1

* 0722_2

* 0722_3

* 0722_4

* 0722_4

* 0722_5

* 0722_6

* 0722_7

* 0723_1

* reformatted

* clang-reformatted

* clang-reformatted2

* 0723_1

* 0723_1

* 0723_1

* 0723_merge3

* 0723_4

* 0723_reformatted_5

* 0724_1

* 0724_1

* 0724_merge1

* 0724_merge2

* 0724_merge3

* 0724_merge3

* 0724_merge4

* 0724_merge5

* 0724_merge6

* 0724_merge7

* 0724_4

* 0724-merge8

* 0724_merge8

* 0725_1

* 0725_6

* 0725_7

* 0725_4padded_image

* 0725_9padded_image

* 0725_10padded_image

* 0725

* 0725_12

* 0725_format

---------

Co-authored-by: bowei.fw <bowei.fw@antgroup.com>
Co-authored-by: nuzant <meizhiyu.mzy@antgroup.com>
Co-authored-by: 朱晗 <lichangye.lcy@antgroup.com>
Changye Li 2025-07-28 21:06:33 +08:00 committed by GitHub
parent e2a3579733
commit 7fb6a80e48
29 changed files with 1546 additions and 178 deletions

View File

@ -611,8 +611,15 @@ class ClusterSpecConfig:
@dataclass
class DatasetConfig:
path: str = field(
default=MISSING,
metadata={
"help": "Path to the dataset. Can be a local path or a HuggingFace dataset name."
},
)
type: Optional[str] = field(
default=None, metadata={"help": "Type of implemented dataset"}
default=None,
metadata={"help": "Type of training method.e.g., 'sft', 'rl', etc."},
)
batch_size: int = field(
default=1, metadata={"help": "Batch size of the dataloader"}
@ -743,7 +750,7 @@ class BaseExperimentConfig:
tokenizer_path: str = field(default="")
train_dataset: DatasetConfig = field(default_factory=DatasetConfig)
valid_dataset: DatasetConfig = field(default_factory=DatasetConfig)
valid_dataset: Optional[DatasetConfig] = field(default=None)
saver: SaverConfig = field(default_factory=SaverConfig)
checkpointer: SaverConfig = field(default_factory=SaverConfig)

View File

@ -8,7 +8,10 @@ import uuid
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple
from transformers import PreTrainedTokenizerFast
import torch
from gymnasium.core import ActType, ObsType
from PIL.Image import Image as ImageObject
from transformers import AutoProcessor, PreTrainedTokenizerFast
from arealite.api.cli_args import GenerationHyperparameters, SaverConfig
from arealite.utils.network import find_free_ports, gethostip
@ -51,6 +54,16 @@ class LLMResponse:
return len(self.output_tokens)
@dataclass
class VLMRequest(LLMRequest):
image_data: Optional[List[ImageObject | str]] = field(default_factory=list)
@dataclass
class VLMResponse(LLMResponse):
input_images: List[ImageObject | str] = field(default_factory=list)
@dataclass
class FinetuneSpec:
total_train_epochs: int
@ -216,7 +229,8 @@ class SaveLoadMeta:
path: str
weight_format: str
with_optim: bool
tokenizer: Optional[PreTrainedTokenizerFast]
tokenizer: PreTrainedTokenizerFast | None
processor: AutoProcessor | None
base_model_path: str | None
naive_distributed: bool = False

View File

@ -0,0 +1,47 @@
from typing import Optional
import transformers
VALID_DATASETS = ["gsm8k", "clevr_count_70k"]
def get_custom_dataset(
path: str,
rank: int,
world_size: int,
type: str = "sft",
split: Optional[str] = None,
tokenizer: Optional[transformers.PreTrainedTokenizerFast] = None,
processor: Optional[transformers.AutoProcessor] = None,
**kwargs,
):
if "gsm8k" in path and type == "sft":
from examples.arealite.dataset.gsm8k import get_gsm8k_sft_dataset
return get_gsm8k_sft_dataset(path, split, tokenizer, rank, world_size, **kwargs)
elif "gsm8k" in path and type == "rl":
from examples.arealite.dataset.gsm8k import get_gsm8k_rl_dataset
return get_gsm8k_rl_dataset(path, split, rank, world_size, **kwargs)
elif "clevr_count_70k" in path and type == "sft":
from examples.arealite.dataset.clevr_count_70k import (
get_clevr_count_70k_sft_dataset,
)
return get_clevr_count_70k_sft_dataset(
path, split, processor, rank, world_size, **kwargs
)
elif "clevr_count_70k" in path and type == "rl":
from examples.arealite.dataset.clevr_count_70k import (
get_clevr_count_70k_rl_dataset,
)
return get_clevr_count_70k_rl_dataset(
path, split, processor, rank, world_size, **kwargs
)
else:
raise ValueError(
f"Dataset {path} with split {split} and training type {type} is not supported. "
f"Supported datasets are: {VALID_DATASETS}. "
)

View File

@ -9,6 +9,8 @@ from tensordict import TensorDict
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoModelForImageTextToText,
AutoProcessor,
PretrainedConfig,
PreTrainedTokenizerFast,
get_constant_schedule_with_warmup,
@ -29,8 +31,8 @@ from arealite.utils.data import (
unsqueeze_mb_list,
)
from arealite.utils.fsdp import get_cosine_schedule_with_warmup
from arealite.utils.model import disable_dropout_in_model
from realhf.api.core.data_api import load_hf_tokenizer
from arealite.utils.model import VALID_VISION_MODELS, disable_dropout_in_model
from realhf.api.core.data_api import load_hf_processor_and_tokenizer, load_hf_tokenizer
from realhf.base import constants, logging
logger = logging.getLogger("Base HF Engine")
@ -44,6 +46,7 @@ class BaseHFEngine(TrainEngine):
self.model: torch.nn.Module
self.optimizer: torch.optim.Optimizer
self.tokenizer: PreTrainedTokenizerFast
self.processor: AutoProcessor | None = None
# huggingface model config
self.model_config: PretrainedConfig
self._version: int = 0
@ -54,6 +57,12 @@ class BaseHFEngine(TrainEngine):
self._parallelism_group: dist.ProcessGroup
self.weight_update_group_initialized = False
self.model_config = AutoConfig.from_pretrained(
pretrained_model_name_or_path=self.config.path,
trust_remote_code=True,
)
self.is_vision_model = self.model_config.model_type in VALID_VISION_MODELS
self.world_size = int(os.environ["WORLD_SIZE"])
def set_version(self, version: int):
@ -92,32 +101,54 @@ class BaseHFEngine(TrainEngine):
self.device = torch.device(int(os.environ["LOCAL_RANK"]))
dtype = getattr(torch, self.config.dtype)
self.model_config = AutoConfig.from_pretrained(
pretrained_model_name_or_path=self.config.path,
trust_remote_code=True,
)
self.tokenizer = load_hf_tokenizer(self.config.path)
tik = time.perf_counter()
with torch.device("cuda"):
if self.config.init_from_scratch:
# initialize scratch model from config
# NOTE: A VLM cannot directly load a state dict into this
# randomly initialized model, so for VLMs we call
# from_pretrained instead of loading weights into a random model.
model = AutoModelForCausalLM.from_config(
self.model_config,
torch_dtype=dtype,
attn_implementation=self.config.attn_impl,
if self.is_vision_model:
if dtype == torch.float16:
raise ValueError(
"Vision models do not support float16 dtype. Please use bfloat16."
)
else:
model = AutoModelForCausalLM.from_pretrained(
if self.config.init_from_scratch:
raise ValueError(
"Vision models do not support initialization from scratch. Please use a pretrained model."
)
self.processor, self.tokenizer = load_hf_processor_and_tokenizer(
self.config.path
)
tik = time.perf_counter()
with torch.device("cuda"):
model = AutoModelForImageTextToText.from_pretrained(
pretrained_model_name_or_path=self.config.path,
trust_remote_code=True,
torch_dtype=dtype,
attn_implementation=self.config.attn_impl,
)
if self.config.disable_dropout:
disable_dropout_in_model(model)
if self.config.disable_dropout:
disable_dropout_in_model(model)
else:
self.tokenizer = load_hf_tokenizer(self.config.path)
tik = time.perf_counter()
with torch.device("cuda"):
if self.config.init_from_scratch:
# initialize scratch model from config
# NOTE: A VLM cannot directly load a state dict into this
# randomly initialized model, so for VLMs we call
# from_pretrained instead of loading weights into a random model.
model = AutoModelForCausalLM.from_config(
self.model_config,
torch_dtype=dtype,
attn_implementation=self.config.attn_impl,
)
else:
model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=self.config.path,
trust_remote_code=True,
torch_dtype=dtype,
attn_implementation=self.config.attn_impl,
)
if self.config.disable_dropout:
disable_dropout_in_model(model)
if self.config.gradient_checkpointing:
model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False}
@ -218,9 +249,15 @@ class BaseHFEngine(TrainEngine):
def prepare_mb_list(self, input_: TensorDict) -> MicroBatchList:
assert "attention_mask" in input_ and "input_ids" in input_
if self.is_vision_model:
assert (
"pixel_values" in input_ and "image_grid_thw" in input_
), "For vision-language models, pixel_values and image_grid_thw must be present in input_"
if isinstance(input_, dict):
input_ = TensorDict(input_, batch_size=[input_["input_ids"].shape[0]])
input_ = amend_position_ids(input_)
mb_list = split_padded_tensor_dict_into_mb_list(input_, self.config.mb_spec)
logger.info(
f"Microbatch #tokens (rank {dist.get_rank()}): {mb_list.group_lens}"
@ -230,6 +267,7 @@ class BaseHFEngine(TrainEngine):
# NOTE: We unsqueeze here because huggingface transformer models require
# packed input to be of shape [1, total_seqlen].
mb_list = unsqueeze_mb_list(mb_list)
# FIXME: the resulting max_seqlen is a tensor rather than an integer
for mb in mb_list.mbs:
mb["max_seqlen"] = int(mb["max_seqlen"])
@ -237,6 +275,7 @@ class BaseHFEngine(TrainEngine):
for mb in mb_list.padded_mbs:
mb["max_seqlen"] = int(mb["max_seqlen"])
mb["use_cache"] = False
return mb_list
def train_batch(
@ -264,11 +303,13 @@ class BaseHFEngine(TrainEngine):
for i, (pad_length, padded_mb_input, mb_input) in enumerate(
zip(mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs)
):
outputs = self.model(**padded_mb_input)
logits = outputs.logits.squeeze(0)
logits = logits[:-pad_length] if pad_length > 0 else logits
loss = loss_fn(logits, mb_input)
loss_scale = loss_weight_fn(mb_input) / total_loss_weight
# Scale loss for accumulation

View File

@ -11,7 +11,7 @@ from torch.distributed.checkpoint.state_dict import (
StateDictOptions,
get_model_state_dict,
)
from transformers import PreTrainedTokenizerFast
from transformers import AutoProcessor, PreTrainedTokenizerFast
from arealite.api.cli_args import TrainEngineConfig
from arealite.api.engine_api import FinetuneSpec
@ -27,6 +27,7 @@ from arealite.utils.fsdp import (
fsdp2_load_full_state_dict,
)
from arealite.utils.save_load import get_state_dict_from_repo_id_or_path
from realhf.api.core.data_api import load_hf_processor_and_tokenizer
from realhf.base import logging, name_resolve, names, pkg_version
logger = logging.getLogger("FSDPEngine")
@ -77,7 +78,7 @@ class FSDPEngine(BaseHFEngine):
def save(self, meta: SaveLoadMeta):
if meta.weight_format == "hf":
self._save_model_to_hf(meta.path, meta.tokenizer)
self._save_model_to_hf(meta.path, meta.tokenizer, meta.processor)
elif meta.weight_format == "dcp":
# TODO: implement DCP save/load for FSDP
raise NotImplementedError("DCP format saving is not implemented yet. ")
@ -100,7 +101,10 @@ class FSDPEngine(BaseHFEngine):
self.load_optimizer_state(meta.path)
def _save_model_to_hf(
self, path: str, tokenizer: Optional[PreTrainedTokenizerFast]
self,
path: str,
tokenizer: Optional[PreTrainedTokenizerFast],
processor: Optional[AutoProcessor],
):
"""Save model in HuggingFace format."""
if self.model is None:
@ -119,6 +123,8 @@ class FSDPEngine(BaseHFEngine):
self.model_config.save_pretrained(path)
if tokenizer is not None:
tokenizer.save_pretrained(path)
if processor is not None:
processor.save_pretrained(path)
dist.barrier(device_ids=[self.device.index])
@ -144,13 +150,13 @@ class FSDPEngine(BaseHFEngine):
dist.barrier(device_ids=[self.device.index])
torch.cuda.synchronize()
elif meta.type == "disk":
self._save_model_to_hf(meta.path, self.tokenizer)
self._save_model_to_hf(meta.path, self.tokenizer, self.processor)
# dist.barrier() is called when _save_model_to_hf finishes
if dist.get_rank() == 0:
update_name = names.update_weights_from_disk(
self.config.experiment_name,
self.config.trial_name,
self.model_version,
self.get_version(),
)
name_resolve.add(
update_name, str(datetime.now().timestamp()), keepalive_ttl=120
@ -247,9 +253,11 @@ class FSDPEngine(BaseHFEngine):
loss.backward()
# NOTE: grad norm clip function is different
grad_norm = fsdp2_clip_grad_norm_(
self.model.parameters(), max_norm=self.optimizer_config.gradient_clipping
)
if not torch.isfinite(grad_norm):
self.optimizer.zero_grad()
update_successful = False

View File

@ -52,7 +52,6 @@ def compute_packed_sft_loss(logits: torch.Tensor, input_: TensorDict) -> torch.T
logprobs = torch.where(loss_mask, logprobs, 0)
loss = -logprobs.sum() / loss_mask.count_nonzero()
with torch.no_grad():
seqlogp = torch.zeros(
cu_seqlens.shape[0] - 1, device=logits.device, dtype=torch.float64

View File

@ -24,6 +24,8 @@ from arealite.api.io_struct import (
LLMRequest,
LLMResponse,
RolloutStat,
VLMRequest,
VLMResponse,
WeightUpdateMeta,
)
from arealite.utils.data import concat_padded_tensors
@ -219,7 +221,9 @@ class RemoteSGLangEngine(InferenceEngine):
return server
raise NotImplementedError("Only round-robin scheduling is implemented.")
async def agenerate(self, req: LLMRequest) -> LLMResponse:
async def agenerate(
self, req: LLMRequest | VLMRequest
) -> LLMResponse | VLMResponse:
"""Async version of generate using aiohttp."""
# Prepare request payload
gconfig = req.gconfig
@ -237,14 +241,23 @@ class RemoteSGLangEngine(InferenceEngine):
"temperature": 0.0 if gconfig.greedy else gconfig.temperature,
"stop_token_ids": stop_token_ids,
}
# NOTE: rid should NOT be passed in payload
payload = {
"input_ids": req.input_ids.copy(),
"sampling_params": sample_params,
"return_logprob": True,
"stream": False,
}
if isinstance(req, VLMRequest):
# VLMRequest has image_data
payload = {
"input_ids": req.input_ids.copy(),
"image_data": req.image_data, # ImageObject or str
"sampling_params": sample_params,
"return_logprob": True,
"stream": False,
}
else:
# NOTE: rid should NOT be passed in payload
payload = {
"input_ids": req.input_ids.copy(),
"sampling_params": sample_params,
"return_logprob": True,
"stream": False,
}
# Make request
start_time = time.perf_counter()
@ -313,15 +326,28 @@ class RemoteSGLangEngine(InferenceEngine):
latency = time.perf_counter() - start_time
return LLMResponse(
input_tokens=req.input_ids,
output_tokens=accumulated_output_tokens,
output_logprobs=accumulated_output_logprobs,
output_versions=accumulated_versions,
stop_reason=stop_reason,
latency=latency,
ttft=latency, # Simplified for non-streaming
)
if isinstance(req, VLMRequest):
response = VLMResponse(
input_tokens=req.input_ids,
input_images=req.image_data,
output_tokens=accumulated_output_tokens,
output_logprobs=accumulated_output_logprobs,
output_versions=accumulated_versions,
stop_reason=stop_reason,
latency=latency,
ttft=latency, # Simplified for non-streaming
)
else:
response = LLMResponse(
input_tokens=req.input_ids,
output_tokens=accumulated_output_tokens,
output_logprobs=accumulated_output_logprobs,
output_versions=accumulated_versions,
stop_reason=stop_reason,
latency=latency,
ttft=latency, # Simplified for non-streaming
)
return response
def update_weights(self, meta: WeightUpdateMeta):
for addr in self.addresses:
@ -456,6 +482,7 @@ class RemoteSGLangEngine(InferenceEngine):
):
try:
data = next(self.data_generator)
except StopIteration:
self.data_generator = iter(dataloader)
data = next(self.data_generator)
@ -555,6 +582,7 @@ def update_weights_from_distributed(
for addr in addresses
]
)
logger.info(f"Distributed update weights done in {time.perf_counter() - tik}s")
return uvloop.run(_fn())

View File

@ -17,7 +17,6 @@ from realhf.base import gpu_utils, logging, name_resolve, names
from realhf.scheduler.client import JobException, JobInfo, JobState
logger = logging.getLogger("Local Scheduler")
JOB_STATE_TO_PROCESS_STATUS = {
JobState.NOT_FOUND: [],
JobState.PENDING: [psutil.STATUS_PARKED],

View File

@ -63,23 +63,35 @@ def pad_sequences_to_tensors(
) -> TensorDict:
if not sequence_list:
return TensorDict()
max_length = max(len(seq) for item in sequence_list for seq in item.values())
skip_keys = {"pixel_values", "image_grid_thw"}
max_length = max(
len(seq)
for item in sequence_list
for key, seq in item.items()
if key not in skip_keys
)
result = {}
for key in sequence_list[0].keys():
padded = []
if key in skip_keys:
result[key] = [sequence_list[i][key] for i in range(len(sequence_list))]
continue
for item in sequence_list:
x = item[key]
if not torch.is_tensor(x):
x = torch.tensor(x)
padded.append(
torch.nn.functional.pad(
x, (0, max_length - len(item[key])), value=pad_value
)
padded_x = torch.nn.functional.pad(
x, (0, max_length - len(item[key])), value=pad_value
)
padded.append(padded_x)
result[key] = torch.stack(padded)
attention_mask = [
[1] * len(next(iter(item.values())))
+ [0] * (max_length - len(next(iter(item.values()))))
[1] * len(next(iter(item[key] for key in item.keys() if key not in skip_keys)))
+ [0]
* (
max_length
- len(next(iter(item[key] for key in item.keys() if key not in skip_keys)))
)
for item in sequence_list
]
result["attention_mask"] = torch.tensor(attention_mask, dtype=torch.bool)
@ -121,9 +133,11 @@ def concat_padded_tensors(
assert all("attention_mask" in td for td in tensor_dicts)
max_length = max([x["attention_mask"].shape[1] for x in tensor_dicts])
result = {}
# Process each key
for key in tensor_dicts[0].keys():
tensors_to_concat = []
for tensor_dict in tensor_dicts:
tensor = tensor_dict[key]
# Skip 1D tensors like rewards
@ -131,6 +145,9 @@ def concat_padded_tensors(
tensors_to_concat.append(tensor)
continue
current_length = tensor.shape[1]
if key == "pixel_values" or key == "image_grid_thw":
tensors_to_concat.append(tensor)
continue
if current_length < max_length:
# Pad tensor to max_length
pad_width = max_length - current_length
@ -139,11 +156,13 @@ def concat_padded_tensors(
padding = torch.zeros(
(tensor.shape[0], pad_width), dtype=tensor.dtype
)
else:
# Pad feature tensors with pad_value
padding = torch.full(
(tensor.shape[0], pad_width), pad_value, dtype=tensor.dtype
)
tensor = torch.cat([tensor, padding], dim=1)
tensors_to_concat.append(tensor)
@ -226,12 +245,14 @@ def pack_tensor_dict(data: TensorDict):
total_length = int(cu_seqlens[-1].item())
# Pack tensors
packed_data = {}
packed_data["cu_seqlens"] = cu_seqlens
packed_data["max_seqlen"] = max_seqlen
for key, value in data.items():
if key == "attention_mask":
packed_data["cu_seqlens"] = cu_seqlens
packed_data["max_seqlen"] = max_seqlen
# tensor and of shape [B, S, ...]
elif (
# if key == "attention_mask":
# packed_data["cu_seqlens"] = cu_seqlens
# packed_data["max_seqlen"] = max_seqlen
# # tensor and of shape [B, S, ...]
if (
torch.is_tensor(value)
and value.ndim >= 2
and value.shape[0] == bs
@ -310,6 +331,8 @@ def split_padded_tensor_dict_into_mb_list(
to_split = {}
not_to_split = {}
for key, value in data.items():
if key == "image_grid_thw" or key == "pixel_values":
continue
if not torch.is_tensor(value) or value.numel() != bs * max_seqlen:
not_to_split[key] = value
else:
@ -343,6 +366,25 @@ def split_padded_tensor_dict_into_mb_list(
return splitted
to_split = dict_map(to_split, lambda x: _split(x))
if data.get("pixel_values", None) is not None:
pixel_values = data.get("pixel_values", [])
image_grid_thw = data.get("image_grid_thw", [])
# Prepare the pixel_values and image_grid_thw for each group
pixel_values_split = []
image_grid_thw_split = []
for group_index in group_indices:
group_pixel_values = [pixel_values[i] for i in group_index]
group_image_grid_thw = [image_grid_thw[i].squeeze() for i in group_index]
# Stack pixel_values for each group (assuming pixel_values is a list of tensors)
pixel_values_split.append(torch.stack(group_pixel_values))
image_grid_thw_split.append(torch.stack(group_image_grid_thw))
# Pack the split pixel_values and image_grid_thw back into the data
to_split["pixel_values"] = pixel_values_split
to_split["image_grid_thw"] = image_grid_thw_split
mbs = dict_of_list2list_of_dict(to_split)
results = []
@ -447,7 +489,11 @@ def unsqueeze_packed_tensor_dict(data: TensorDict) -> TensorDict:
new_data = {}
for key, value in data.items():
if (
key not in ["cu_seqlens", "max_seqlen"]
key
not in [
"cu_seqlens",
"max_seqlen",
]
and torch.is_tensor(value)
and value.numel() == total_length
):

View File

@ -46,6 +46,7 @@ def fsdp2_clip_grad_norm_(
grads = [p.grad for p in parameters if p.grad is not None]
total_norm = _get_total_norm(grads, norm_type, error_if_nonfinite, foreach)
total_norm = total_norm.to(torch.cuda.current_device(), non_blocking=True)
_clip_grads_with_norm_(parameters, max_norm, total_norm, foreach)
return total_norm

arealite/utils/image.py
View File

@ -0,0 +1,45 @@
import base64
import math
from dataclasses import MISSING
from io import BytesIO
from typing import List
from PIL import Image
from PIL.Image import Image as ImageObject
def image2base64(images: List[ImageObject] | ImageObject) -> List[str] | str:
if isinstance(images, ImageObject):
images = [images]
byte_images = []
for image in images:
with BytesIO() as buffer:
image.save(buffer, format="PNG")
buffer.seek(0)
byte_image = base64.b64encode(buffer.read()).decode("utf-8")
byte_images.append(byte_image)
return byte_images
def pad_images_batch_to_max_size(images):
max_width = max(image.size[0] for image in images)
max_height = max(image.size[1] for image in images)
padded_images = []
for image in images:
width, height = image.size
padding_left = (max_width - width) // 2
padding_top = (max_height - height) // 2
padded_image = Image.new("RGB", (max_width, max_height), (0, 0, 0))
padded_image.paste(image, (padding_left, padding_top))
padded_images.append(padded_image)
return padded_images

View File

@ -1,5 +1,12 @@
import torch
VALID_VISION_MODELS = [
"qwen2_vl",
"qwen2_5_vl",
]
# This registry lists vision models that we have verified to work with AReaLite. Because vision models differ in their image processing, special tokens, keys, etc., we will add models to this registry only after testing them.
# If you want to add a new vision model, please make sure it works with AReaLite.
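# Usage sketch (illustration only, mirroring BaseHFEngine in this PR): engines decide
# the vision code path by checking the HuggingFace config's model_type against this
# registry, e.g.
#   model_config = AutoConfig.from_pretrained(path, trust_remote_code=True)
#   is_vision_model = model_config.model_type in VALID_VISION_MODELS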
# Copied from trl
def disable_dropout_in_model(model: torch.nn.Module) -> None:

View File

@ -1,7 +1,7 @@
import getpass
import os
from transformers import PreTrainedTokenizerFast
from transformers import AutoProcessor, PreTrainedTokenizerFast
from arealite.api.cli_args import SaverConfig
from arealite.api.engine_api import TrainEngine
@ -56,6 +56,7 @@ class Saver:
global_step: int,
name: str = "default",
tokenizer: PreTrainedTokenizerFast | None = None,
processor: AutoProcessor | None = None,
base_model_path: str | None = None,
):
if not self.freq_ctl.check(
@ -76,6 +77,7 @@ class Saver:
weight_format=weight_format,
with_optim=with_optim,
tokenizer=tokenizer,
processor=processor,
base_model_path=base_model_path,
)
engine.save(meta)

View File

@ -38,6 +38,7 @@ class RLVRWorkflow(RolloutWorkflow):
add_generation_prompt=True,
enable_thinking=self.enable_thinking,
)
n_samples = self.gconfig.n_samples
req = LLMRequest(
rid=uuid.uuid4().hex,

View File

@ -0,0 +1,121 @@
import asyncio
import os
import uuid
import colorama
import torch
from tensordict import TensorDict
from transformers import AutoProcessor, PreTrainedTokenizerFast
from arealite.api.cli_args import GenerationHyperparameters
from arealite.api.io_struct import VLMRequest
from arealite.utils.data import concat_padded_tensors
from arealite.utils.image import image2base64, pad_images_batch_to_max_size
from arealite.workflow.rlvr import RLVRWorkflow
class VisionRLVRWorkflow(RLVRWorkflow):
def __init__(
self,
reward_fn,
gconfig: GenerationHyperparameters,
tokenizer: PreTrainedTokenizerFast,
processor: AutoProcessor,
enable_thinking: bool,
dump_dir: str | None = None,
):
super().__init__(reward_fn, gconfig, tokenizer, enable_thinking, dump_dir)
self.processor = processor
async def arun_episode(self, engine, data):
padded_images = pad_images_batch_to_max_size(data["images"])
processed_input = self.processor(
images=padded_images,
text=data["messages"],
padding=False,
return_tensors="pt",
)
input_ids = processed_input["input_ids"].tolist()[0]
n_samples = self.gconfig.n_samples
byte_images = image2base64(padded_images)
req = VLMRequest(
rid=uuid.uuid4().hex,
input_ids=input_ids,
image_data=byte_images,
gconfig=self.gconfig.new(n_samples=1),
)
resps = await asyncio.gather(*[engine.agenerate(req) for _ in range(n_samples)])
version = engine.get_version()
prompt_strs = []
completions_strs = []
rewards = []
seqlens = []
results = []
for resp in resps:
seq = resp.input_tokens + resp.output_tokens
logprobs = [0.0] * resp.input_len + resp.output_logprobs
loss_mask = [0] * resp.input_len + [1] * resp.output_len
versions = [-1] * resp.input_len + resp.output_versions
prompt_str = self.tokenizer.decode(input_ids)
completions_str = self.tokenizer.decode(resp.output_tokens)
prompt_strs.append(prompt_str)
completions_strs.append(completions_str)
seqlens.append(len(seq))
reward = self.reward_fn(
prompt=prompt_str,
completions=completions_str,
prompt_ids=resp.input_tokens,
completion_ids=resp.output_tokens,
**data,
)
rewards.append(reward)
res = dict(
# unsqueeze to add an additional batch dimension
input_ids=torch.tensor(seq).unsqueeze(0),
loss_mask=torch.tensor(loss_mask).unsqueeze(0),
pixel_values=processed_input["pixel_values"].unsqueeze(0),
image_grid_thw=processed_input["image_grid_thw"].unsqueeze(0),
logprobs=torch.tensor(logprobs).unsqueeze(0),
versions=torch.tensor(versions).unsqueeze(0),
attention_mask=torch.ones(len(seq), dtype=torch.bool).unsqueeze(0),
# reward
rewards=torch.tensor([reward]),
)
results.append(TensorDict(res, batch_size=[1]))
if self.dump_dir is not None:
os.makedirs(os.path.join(self.dump_dir, str(version)), exist_ok=True)
# Get the unique identifier for this prompt
qid = None
for key in ["query_id", "id", "qid"]:
qid = data.get(key, None)
if qid is not None:
break
qid = qid or uuid.uuid4().hex
# Dump rollout to file
with open(
os.path.join(self.dump_dir, str(version), f"{qid}.txt"), "a"
) as f:
n_samples = self.gconfig.n_samples
for i, (p, c, r, sl) in enumerate(
zip(prompt_strs, completions_strs, rewards, seqlens)
):
info = "\n".join(
[
f"idx: {i + 1} / {n_samples}, seqlen: {sl}, reward is {r}.",
f"prompt is \n{colorama.Fore.YELLOW + colorama.Style.DIM}{p}{colorama.Style.RESET_ALL}",
f"sequence is: \n{colorama.Fore.YELLOW + colorama.Style.DIM}{c}{colorama.Style.RESET_ALL}",
]
)
f.write(info + "\n")
return concat_padded_tensors(results)

View File

@ -1,17 +1,20 @@
# Rollout and Agentic RL
This guide shows you how to create custom rollout behaviors for RL training by building
a multi-turn math agent with **AReaLite**. This agent keeps trying to solve math
problems until it finds the correct answer.
This guide demonstrates how to customize rollout behavior for PPO training by
implementing a multi-turn math agent that uses end-to-end reinforcement learning. Our
example agent will continuously try to solve a math problem until it reaches the correct
answer.
You can find the complete implementation in `arealite/workflow/multi_turn.py`.
## Approach: Using AReaLite (Recommended)
## Step 1: Define Your Workflow
The complete implementation is placed at `arealite/workflow/multi_turn.py`.
AReaLite gives you flexibility in how you design your agents. Instead of rigid `Agent`
classes that might constrain your agent's capabilities, AReaLite captures all rollout
behavior in a `RolloutWorkflow` class. This approach lets you customize your agent's
behavior however you need.
### Step 1: Define Your Workflow
AReaLite takes a flexible approach to agent definition. Rather than using rigid `Agent`
classes that might limit your agentic capabilities, AReaLite captures all rollout
behavior in a `RolloutWorkflow` class. This design gives you complete freedom to
customize your agent's behavior.
```python
# arealite/api/workflow_api.py
@ -26,8 +29,8 @@ class RolloutWorkflow:
raise NotImplementedError()
```
The workflow exposes an `arun_episode` method that runs and collects data from a single
episode. This method takes two key arguments:
The workflow exposes a single `arun_episode` method that runs and collects data from a
single episode. This method takes two key arguments:
1. **InferenceEngine**: Provides the `agenerate` method for generating responses to user
inputs
@ -36,19 +39,14 @@ episode. This method takes two key arguments:
Within this method, you have complete control over how your agent and environment
interact.
> **Note**: Each `arun_episode` call takes a single prompt and outputs the trajectories
> generated from that prompt—it's not batched. However, you can generate multiple
> trajectories from a single prompt (for example, with GRPO or tree search).
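For instance, a GRPO-style workflow can fan several generations out of a single prompt by issuing one single-sample request per trajectory. Below is a minimal sketch (the `collect_group` helper is hypothetical; the `LLMRequest`/`agenerate`/`gconfig.new` usage mirrors the RLVR workflow added in this PR):
```python
import asyncio
import uuid

from arealite.api.io_struct import LLMRequest


async def collect_group(engine, tokenizer, gconfig, data, n_samples: int):
    """Generate several independent trajectories from one prompt (GRPO-style)."""
    input_ids = tokenizer.apply_chat_template(
        data["messages"], tokenize=True, add_generation_prompt=True
    )
    req = LLMRequest(
        rid=uuid.uuid4().hex,
        input_ids=input_ids,
        gconfig=gconfig.new(n_samples=1),  # one sample per request
    )
    # Issue n_samples requests concurrently; each response is one trajectory.
    return await asyncio.gather(*[engine.agenerate(req) for _ in range(n_samples)])
```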
### Setting Up the Multi-Turn Math Workflow
#### Setting Up the Multi-Turn Math Workflow
Let's build a multi-turn rollout workflow for solving math problems. First, we'll define
the `__init__` method to set up what we need during rollout:
the `__init__` method to capture the utilities we need during rollout:
> **Note**: You have complete flexibility in defining the `__init__` method. Pass
> whatever arguments you need to construct your workflow. If you want to use tools, pass
> the corresponding environment here so your agent can call it in the `arun_episode`
> method.
> **Note**: You have complete flexibility in defining the `__init__` method. Pass any
> arguments needed to construct your workflow. If you want to use tools, pass the
> corresponding environment here so your agent can call it in the `arun_episode` method.
```python
class MultiTurnWorkflow(RolloutWorkflow):
@ -68,7 +66,7 @@ class MultiTurnWorkflow(RolloutWorkflow):
self.turn_discount = turn_discount
```
### Implementing the Episode Logic
#### Implementing the Episode Logic
Now let's implement the `arun_episode` method. We'll start by tokenizing the prompt data
and converting it into an `LLMRequest` object for the inference engine:
@ -102,20 +100,20 @@ class MultiTurnWorkflow(RolloutWorkflow):
# ... continue processing ...
```
> **Note**: This example uses the "messages" key from the prompt data to get
> OpenAI-compatible messages. This isn't required—the key and prompt format depend
> **Note**: This example accesses the "messages" key from the prompt data to get
> OpenAI-compatible messages. This isn't mandatory—the key and prompt format depend
> entirely on your implementation. For instance, if your dataset stores prompt strings
> in a "prompt" column, you could get input token IDs with
> in a "prompt" column, you could get input token IDs via
> `self.tokenizer.encode(data["prompt"])`.
> **Note**: The `rid` field in `LLMRequest` is the request ID. Requests with the same ID
> will reuse the LLM inference server's KV caches for better efficiency.
> will reuse the LLM inference server's KV caches for efficiency.
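For example, a multi-turn episode can keep one `rid` for the whole conversation so each turn extends the same cached prefix. A rough sketch (the `run_multi_turn` helper and its arguments are hypothetical):
```python
import uuid

from arealite.api.io_struct import LLMRequest


async def run_multi_turn(engine, token_ids, gconfig, max_turns: int):
    rid = uuid.uuid4().hex  # one request ID per episode -> KV cache reuse across turns
    for _ in range(max_turns):
        req = LLMRequest(rid=rid, input_ids=token_ids, gconfig=gconfig.new(n_samples=1))
        resp = await engine.agenerate(req)
        token_ids = token_ids + resp.output_tokens
        # ... evaluate the answer; append feedback tokens and continue, or break ...
    return token_ids
```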
### Handling Multi-Turn Conversations
#### Handling Multi-Turn Conversations
Next, we'll check if the current answer is correct using our `reward_fn`. This function
should return 1 for correct answers and 0 otherwise. When the answer is wrong, we'll
apply a discount, add feedback to the conversation, and let the model try again:
Next, we'll evaluate whether the current answer is correct using our `reward_fn`. This
function should return 1 for correct answers and 0 otherwise. When the answer is wrong,
we'll apply a discount, add feedback to the conversation, and let the model try again:
```python
class MultiTurnWorkflow(RolloutWorkflow):
@ -150,10 +148,10 @@ class MultiTurnWorkflow(RolloutWorkflow):
discount *= self.turn_discount
```
### Reward Function Signature
#### Reward Function Signature
To make it easier to switch between different reward functions, we recommend following
this signature:
For convenience when switching between different reward functions, we recommend
following this pre-defined signature:
```python
def reward_fn(
@ -180,10 +178,10 @@ def reward_fn(
"""
```
While this signature is convenient, you're not restricted to it in custom
workflows—modify as needed for your specific use case.
While this signature is convenient, there are no strict restrictions on reward functions
in custom workflows—modify them as needed for your specific use case.
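As a concrete illustration, here is a minimal exact-match reward in that shape, loosely modeled on the CLEVR counting reward added elsewhere in this PR (the bracketed-number regex and the `answer` keyword argument are assumptions borrowed from that example):
```python
import re


def count_reward_fn(prompt, completions, prompt_ids, completion_ids, answer, **kwargs):
    """Return 1 when the last bracketed number in the completion matches the label."""
    matches = re.findall(r"\[([0-9\.]+)\]", completions)
    if not matches or answer is None:
        return 0
    return 1 if matches[-1].strip() == str(answer).strip() else 0
```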
### Collecting Training Data
#### Collecting Training Data
Finally, let's complete the implementation by collecting trajectories in the
`TensorDict` format:
@ -218,6 +216,191 @@ class MultiTurnWorkflow(RolloutWorkflow):
return concat_padded_tensors([res])
```
> **Important**: The returned `TensorDict` must follow HuggingFace's padded data format,
> where each tensor has shape `[batch_size, sequence_length, *]`. This allows AReaLite
> to automatically batch multiple trajectories for the training engine. Since this
> example returns a single trajectory, we use `unsqueeze(0)` to create a size-1 batch.
> **Note**: There are no restrictions on the keys in your `TensorDict`—different
> algorithms require different keys. This example targets the GRPO algorithm, so we
> include `input_ids`, `loss_mask`, `attention_mask`, and `logprobs` (needed for
> computing importance ratios).
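For reference, a single trajectory can be packed into that padded format roughly as follows (a sketch with the GRPO keys listed above; the helper name is hypothetical and mirrors the workflow code later in this guide):
```python
import torch
from tensordict import TensorDict


def to_padded_batch(seq, logprobs, loss_mask, reward):
    """Pack one trajectory as a size-1 padded batch with shape [1, seqlen] per key."""
    res = dict(
        input_ids=torch.tensor(seq).unsqueeze(0),
        logprobs=torch.tensor(logprobs).unsqueeze(0),
        loss_mask=torch.tensor(loss_mask).unsqueeze(0),
        attention_mask=torch.ones(len(seq), dtype=torch.bool).unsqueeze(0),
        rewards=torch.tensor([float(reward)]),  # 1D, one value per trajectory
    )
    return TensorDict(res, batch_size=[1])
```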
### Step 2: Training with Your Custom Workflow
Using your custom workflow is straightforward—just construct it in your training script
and pass it to the `rollout_batch` or `prepare_batch` method:
```python
def main(args):
# ... setup code ...
# Create your custom workflow
workflow = MultiTurnWorkflow(
reward_fn=gsm8k_reward_fn,
gconfig=config.gconfig,
tokenizer=tokenizer,
turn_discount=0.9,
max_turns=5,
)
# Run training—no other changes needed!
data_generator = iter(train_dataloader)
for global_step in range(max_steps):
with stats_tracker.record_timing("rollout"):
if config.async_training:
batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
else:
try:
data = next(data_generator)
except StopIteration:
data_generator = iter(train_dataloader)
data = next(data_generator)
batch = rollout.rollout_batch(data, workflow=workflow)
# ... continue with training loop ...
```
That's it! Your custom multi-turn math agent is now ready to train with reinforcement
learning. The workflow will automatically handle the multi-turn conversations, reward
computation, and data collection needed for effective RL training.
## Alternative Approach: Using the Legacy Version (Not Recommended)
While we strongly recommend using AReaLite for new projects, you might encounter legacy
code that uses the older Agent-based approach. Here's how it works for reference, though
we suggest migrating to the workflow-based system when possible.
### Step 1: Define Your Agent Class
Create a new file under `realhf/impl/agent/`, such as `math_multi_turn_agent.py`. Your
`Agent` must implement the interface defined in `realhf/api/core/agent.py`, which
requires a single method: `collect_trajectory`.
```python
class MathMultiTurnAgent(Agent):
async def collect_trajectory(
self,
reward_fn,
gconfig: GenerationHyperparameters, # aka sampling_params
tokenizer: PreTrainedTokenizerFast,
max_turns: int,
turn_discount: float,
):
# Implementation goes here
...
```
### Step 2: Implement the Trajectory Collection Logic
The `collect_trajectory` method takes a task prompt, an environment, and two
communication queues. Within this method, you control the data flow between your agent
and the inference engine using these queues:
- **obs_queue**: Send observations (token IDs and generation config) to the inference
engine
- **act_queue**: Receive actions (generated responses) from the inference engine
Here's how the multi-turn conversation works:
```python
for turn in range(self.num_turns):
# Send the current state to the inference engine
await obs_queue.put((qid, token_ids, self.gconfig))
# Get the generated response
act: BundledGenerationOutputs = await act_queue.get()
# Evaluate the response through the environment
success, rewards = await env.step((qid, answers))
# ... process results ...
```
#### Environment Integration
The environment follows a
[Gym-like interface](https://github.com/Farama-Foundation/Gymnasium) with `reset` and
`step` methods, but uses asynchronous implementations to prevent blocking across
different environment instances.
For math problems, the environment is typically stateless and acts as a wrapper around
your reward function:
```python
class MathCodeSingleStepEnv(EnvironmentService):
async def step(self, action: Tuple[str, List[str]]):
qid, answers = action
# ... setup code ...
# Run reward computation asynchronously
format_rewards = await asyncio.to_thread(
math_verify_call,
answers,
# ... other parameters ...
)
return None, format_rewards, True, False, {}
```
#### Handling Multi-Turn Feedback
After receiving the reward from `env.step`, check if the answer is correct. If not,
provide feedback and continue to the next turn:
```python
for turn in range(self.num_turns):
# ... generation and evaluation code ...
# Provide feedback based on the result
if success[0]:
feedback = "Congratulations! You are correct!"
else:
feedback = "Unfortunately your answer is wrong. Let's try again."
# Format feedback as a user message
feedback = "\n" + self.tokenizer.apply_chat_template(
[{"content": feedback, "role": "user"}],
add_generation_prompt=True,
tokenize=False,
)
# Add feedback tokens to the conversation
feedback_tokens = self.tokenizer(feedback)["input_ids"]
token_ids.extend(feedback_tokens)
```
### Step 3: Register and Configure Your Agent
First, register your agent implementation:
```python
class MultiTurnWorkflow(RolloutWorkflow):
# ... previous methods ...
async def arun_episode(self, engine: InferenceEngine, data):
# ... episode logic above ...
while reward == 0 and t < self.max_turns:
# ... generation and evaluation ...
# Collect trajectory data
input_len = len(resp.input_tokens) - len(seq)
seq += resp.input_tokens[-input_len:] + resp.output_tokens
logprobs += [0.0] * input_len + resp.output_logprobs
loss_mask += [0] * input_len + [1] * resp.output_len
versions += [-1] * input_len + resp.output_versions
# Package results
res = dict(
input_ids=torch.tensor(seq),
logprobs=torch.tensor(logprobs),
loss_mask=torch.tensor(loss_mask),
versions=torch.tensor(versions),
rewards=torch.tensor(float(reward * discount)),
attention_mask=torch.ones(len(seq), dtype=torch.bool),
)
res = {k: v.unsqueeze(0) for k, v in res.items()}
return concat_padded_tensors([res])
```
> **Important**: The returned `TensorDict` must follow HuggingFace's padded data format,
> where each tensor has shape `[batch_size, sequence_length, *]`. This allows AReaLite
> to automatically batch multiple trajectories for training. Since this example returns
@ -234,34 +417,62 @@ Using your custom workflow is straightforward—just create it in your training
pass it to the `rollout_batch` or `prepare_batch` method:
```python
def main(args):
# ... setup code ...
# Create your custom workflow
workflow = MultiTurnWorkflow(
reward_fn=gsm8k_reward_fn,
gconfig=config.gconfig,
tokenizer=tokenizer,
turn_discount=0.9,
max_turns=5,
)
# Run training—no other changes needed!
data_generator = iter(train_dataloader)
for global_step in range(max_steps):
with stats_tracker.record_timing("rollout"):
if config.async_training:
batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
else:
try:
data = next(data_generator)
except StopIteration:
data_generator = iter(train_dataloader)
data = next(data_generator)
batch = rollout.rollout_batch(data, workflow=workflow)
# ... continue with training loop ...
# in realhf/impl/agent/__init__.py
import realhf.impl.agent.math_multi_turn_agent
```
That's it! Your custom multi-turn math agent is now ready for reinforcement learning
training. The workflow will automatically handle the multi-turn conversations, reward
computation, and data collection needed for effective RL training.
Then update your experiment configuration in
`realhf/experiments/async_exp/async_math_ppo.py`:
```python
@dataclasses.dataclass
class AsyncPPOMATHConfig(AsyncRLExperimentConfig, PPOMATHConfig):
# Add any new CLI arguments your agent needs
my_param: float = 1.0
@property
def agent(self) -> AgentAbstraction:
return AgentAbstraction(
"math-multi-turn", # Your registered agent name
args=dict(
# Pass any arguments needed for your __init__ method
my_param=self.my_param,
# ... other configuration ...
),
)
@property
def env(self) -> EnvServiceAbstraction:
# Update to use your custom environment if needed
return EnvServiceAbstraction(
"math-code-single-step",
args=dict(dataset_path=self.dataset.path)
)
```
### Step 4: Run Training
Follow the standard training procedure outlined in the
[quickstart guide](../tutorial/quickstart.md). Launch your experiment with:
```bash
python3 training/main_async_ppo.py my_param=5.0 # plus any additional CLI arguments
```
### Training Results
Here's an example of the training reward curve from our multi-turn math agent:
![Multi-turn Training Rewards](multiturn_reward.png)
The agent successfully learns to solve math problems with improved accuracy over time,
demonstrating the effectiveness of the multi-turn approach.
______________________________________________________________________
**Note**: While this legacy approach works, we strongly recommend using the AReaLite
workflow system for new projects. It provides better flexibility, cleaner abstractions,
and easier maintenance. Consider migrating existing legacy agents to the workflow-based
approach when possible.
Happy coding!

View File

@ -0,0 +1,261 @@
import os
import re
import sys
import torch
import torch.distributed as dist
import wandb
from torch.utils.data import Subset
from torchdata.stateful_dataloader import StatefulDataLoader
from arealite.api.cli_args import GRPOConfig, load_expr_config
from arealite.api.io_struct import AllocationMode, FinetuneSpec, WeightUpdateMeta
from arealite.dataset.__init__ import get_custom_dataset
from arealite.engine.ppo.actor import FSDPPPOActor
from arealite.engine.sglang_remote import RemoteSGLangEngine
from arealite.utils.device import log_gpu_stats
from arealite.utils.evaluator import Evaluator
from arealite.utils.saver import Saver
from arealite.utils.stats_logger import StatsLogger
from arealite.workflow.vision_rlvr import VisionRLVRWorkflow
from realhf.api.core.data_api import load_hf_processor_and_tokenizer
from realhf.base import stats_tracker
def extract_answer(pred_str, data_name, use_last_number=True):
match = re.findall(r"\[([0-9\.]+)\]", pred_str)
if match:
return match[-1]
return ""
def clevr_count_70k_reward_fn(
prompt, completions, prompt_ids, completion_ids, answer, **kwargs
):
sol = extract_answer(completions, data_name="") # str number
ans = answer
if sol is None:
return 0
if ans is None:
return 0
if sol.strip() == ans.strip():
print(f"completions: {completions}, answer: {answer}")
return 1
return 0
def main(args):
wandb.init(project="clevr_70k")
config, _ = load_expr_config(args, GRPOConfig)
config: GRPOConfig
rank = int(os.getenv("RANK"))
world_size = int(os.getenv("WORLD_SIZE"))
processor, tokenizer = load_hf_processor_and_tokenizer(config.tokenizer_path)
train_dataset = get_custom_dataset(
path=config.train_dataset.path,
rank=rank,
world_size=world_size,
split="train",
type=config.train_dataset.type,
processor=processor,
)
train_size = len(train_dataset)
subset_size = int(1.0 * train_size)
random_indices = torch.randperm(train_size).tolist()[:subset_size]
subset_train_dataset = Subset(train_dataset, random_indices)
valid_dataset = get_custom_dataset(
path=config.valid_dataset.path,
rank=rank,
world_size=world_size,
split="test",
type=config.valid_dataset.type,
processor=processor,
)
# Create dataset and dataloaders
train_dataloader = StatefulDataLoader(
subset_train_dataset,
batch_size=config.train_dataset.batch_size // world_size,
shuffle=config.train_dataset.shuffle,
num_workers=config.train_dataset.num_workers,
collate_fn=lambda x: x,
drop_last=config.train_dataset.drop_last,
)
valid_dataloader = StatefulDataLoader(
valid_dataset,
batch_size=config.valid_dataset.batch_size // world_size,
shuffle=config.valid_dataset.shuffle,
num_workers=config.valid_dataset.num_workers,
collate_fn=lambda x: x,
drop_last=config.valid_dataset.drop_last,
)
ft_spec = FinetuneSpec(
total_train_epochs=config.total_train_epochs,
dataset_size=len(train_dataloader) * config.train_dataset.batch_size,
train_batch_size=config.train_dataset.batch_size,
)
# Initialize inference engine
rollout = RemoteSGLangEngine(config.rollout)
rollout.initialize(None, ft_spec)
eval_rollout = RemoteSGLangEngine(config.rollout)
eval_rollout.initialize(None, ft_spec)
# NOTE: set a large version such that eval does not have any offpolicyness control
eval_rollout.set_version(int(1e12))
# Initialize train engine
actor = FSDPPPOActor(config=config.actor)
actor.initialize(None, ft_spec)
ref = None
if config.actor.kl_ctl > 0 and config.ref is not None:
ref = FSDPPPOActor(config=config.ref)
ref.initialize(None, ft_spec)
# NOTE: Weight update meta only requires address and free port of rank 0,
# but `WeightUpdateMeta.from_fsdp_nccl` has to be executed on all ranks
# due to `engine.get_param_specs()`.
# Therefore, we create weight update meta on all ranks, then broadcast the one on rank 0.
weight_update_meta = [WeightUpdateMeta.from_disk(config.saver)]
dist.broadcast_object_list(weight_update_meta, src=0)
weight_update_meta = weight_update_meta[0]
# Create rollout workflow
if tokenizer.pad_token_id not in config.gconfig.stop_token_ids:
config.gconfig.stop_token_ids.append(tokenizer.pad_token_id)
if tokenizer.eos_token_id not in config.gconfig.stop_token_ids:
config.gconfig.stop_token_ids.append(tokenizer.eos_token_id)
workflow = VisionRLVRWorkflow(
reward_fn=clevr_count_70k_reward_fn,
gconfig=config.gconfig,
tokenizer=tokenizer,
processor=processor,
enable_thinking=False,
)
# Run training.
saver = Saver(config.saver, ft_spec, for_recover=False)
logger = StatsLogger(config.stats_logger, ft_spec)
evaluator = Evaluator(config.evaluator, ft_spec)
total_epochs = config.total_train_epochs
steps_per_epoch = len(train_dataloader)
max_steps = total_epochs * steps_per_epoch
logger.info(f"total_epochs={total_epochs} step_per_epoch={steps_per_epoch}")
data_generator = iter(train_dataloader)
for global_step in range(max_steps):
epoch = global_step // steps_per_epoch
step = global_step % steps_per_epoch
with stats_tracker.record_timing("rollout"):
if config.async_training:
batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
else:
try:
data = next(data_generator)
except StopIteration:
data_generator = iter(train_dataloader)
data = next(data_generator)
batch = rollout.rollout_batch(data, workflow=workflow)
batch = batch.to(actor.device)
# Create barrier to synchronize all rollout processes.
dist.barrier(device_ids=[actor.device.index])
torch.cuda.synchronize()
if config.actor.recompute_logprob or config.actor.use_decoupled_loss:
with stats_tracker.record_timing("recompute_logp"):
logp = actor.compute_logp(batch)
batch["prox_logp"] = logp
log_gpu_stats("recompute logp")
if ref is not None:
with stats_tracker.record_timing("ref_logp"):
batch["ref_logp"] = ref.compute_logp(batch)
log_gpu_stats("ref logp")
with stats_tracker.record_timing("compute_advantage"):
actor.compute_advantages(batch)
log_gpu_stats("compute advantages")
with (
stats_tracker.record_timing("train_step"),
stats_tracker.scope("grpo_actor"),
):
stats = actor.ppo_update(batch)
wandb.log({"final_reward": stats[0]["grpo_actor/final_reward/avg"]})
wandb.log({"task_reward": stats[0]["grpo_actor/task_reward/avg"]})
actor.step_lr_scheduler()
log_gpu_stats("ppo update")
with stats_tracker.record_timing("update_weights"):
rollout.pause()
if dist.get_rank() == 0:
future = rollout.update_weights(weight_update_meta)
actor.upload_weights(weight_update_meta)
if dist.get_rank() == 0:
future.result()
dist.barrier(device_ids=[actor.device.index])
torch.cuda.synchronize()
rollout.resume()
actor.set_version(global_step + 1)
rollout.set_version(global_step + 1)
with stats_tracker.record_timing("save"):
saver.save(actor, epoch, step, global_step)
with stats_tracker.record_timing("eval"):
def evaluate_fn():
rollout.pause()
cnt = 0
for data in valid_dataloader:
for item in data:
eval_rollout.submit(item, workflow)
cnt += 1
batch = eval_rollout.wait(cnt, timeout=None)
rewards = batch["rewards"].float().to(actor.device)
wandb.log({"eval_reward": rewards.mean().item()})
with stats_tracker.scope("grpo-eval"):
stats_tracker.denominator(
n_seqs=torch.ones(
rewards.shape[0],
device=rewards.device,
dtype=torch.bool,
)
)
stats_tracker.stat(task_reward=rewards, denominator="n_seqs")
rollout.resume()
evaluator.evaluate(
evaluate_fn,
epoch,
step,
global_step,
)
logger.commit(epoch, step, global_step, stats)
logger.close()
eval_rollout.destroy()
rollout.destroy()
if ref is not None:
ref.destroy()
actor.destroy()
wandb.finish()
if __name__ == "__main__":
main(sys.argv[1:])

View File

@ -0,0 +1,124 @@
import os
import sys

from torchdata.stateful_dataloader import StatefulDataLoader

from arealite.api.cli_args import SFTConfig, load_expr_config
from arealite.api.io_struct import FinetuneSpec
from arealite.dataset.__init__ import get_custom_dataset
from arealite.engine.sft.lm_engine import FSDPLMEngine
from arealite.utils.data import pad_sequences_to_tensors
from arealite.utils.evaluator import Evaluator
from arealite.utils.saver import Saver
from arealite.utils.stats_logger import StatsLogger
from realhf.api.core.data_api import load_hf_processor_and_tokenizer
from realhf.base import stats_tracker


def main_sft():
    config, _ = load_expr_config(sys.argv[1:], SFTConfig)
    config: SFTConfig

    rank = int(os.getenv("RANK"))
    world_size = int(os.getenv("WORLD_SIZE"))

    processor, tokenizer = load_hf_processor_and_tokenizer(config.tokenizer_path)
    train_dataset = get_custom_dataset(
        path=config.train_dataset.path,
        rank=rank,
        world_size=world_size,
        split="train",
        type=config.train_dataset.type,
        tokenizer=tokenizer,
        processor=processor,
    )
    valid_dataset = get_custom_dataset(
        path=config.valid_dataset.path,
        rank=rank,
        world_size=world_size,
        split="test",
        type=config.valid_dataset.type,
        tokenizer=tokenizer,
        processor=processor,
    )

    # Create dataset and dataloaders
    train_dataloader = StatefulDataLoader(
        train_dataset,
        batch_size=config.train_dataset.batch_size // world_size,
        shuffle=config.train_dataset.shuffle,
        num_workers=config.train_dataset.num_workers,
        collate_fn=pad_sequences_to_tensors,
        drop_last=config.train_dataset.drop_last,
    )
    valid_dataloader = StatefulDataLoader(
        valid_dataset,
        batch_size=config.valid_dataset.batch_size // world_size,
        shuffle=config.valid_dataset.shuffle,
        num_workers=config.valid_dataset.num_workers,
        collate_fn=pad_sequences_to_tensors,
        drop_last=config.valid_dataset.drop_last,
    )

    # Initialize engine
    ft_spec = FinetuneSpec(
        total_train_epochs=config.total_train_epochs,
        dataset_size=len(train_dataloader) * config.train_dataset.batch_size,
        train_batch_size=config.train_dataset.batch_size,
    )
    engine = FSDPLMEngine(config=config.model)
    engine.initialize(None, ft_spec)

    # Run training.
    saver = Saver(config.saver, ft_spec, for_recover=False)
    logger = StatsLogger(config.stats_logger, ft_spec)
    evaluator = Evaluator(config.evaluator, ft_spec)

    total_epochs = config.total_train_epochs
    steps_per_epoch = len(train_dataloader)
    logger.info(f"total_epochs={total_epochs} step_per_epoch={steps_per_epoch}")

    global_step = 0
    for epoch in range(total_epochs):
        for step, data in enumerate(train_dataloader):
            with (
                stats_tracker.record_timing("train_step"),
                stats_tracker.scope("sft"),
            ):
                stats = engine.train_lm(data)
                engine.step_lr_scheduler()
                stats_tracker.scalar(**stats)

            with stats_tracker.record_timing("save"):
                saver.save(engine, epoch, step, global_step)

            with stats_tracker.record_timing("eval"):
                # No need to log anything. Logging will be handled outside
                # via stats_tracker.export().
                def evaluate_fn():
                    with stats_tracker.scope("sft-eval"):
                        for data in valid_dataloader:
                            engine.evaluate_lm(data)

                evaluator.evaluate(
                    evaluate_fn,
                    epoch,
                    step,
                    global_step,
                )

            logger.commit(
                epoch,
                step,
                global_step,
                stats_tracker.export(reduce_group=engine.parallelism_group),
            )
            global_step += 1

    logger.close()
    engine.destroy()


if __name__ == "__main__":
    main_sft()

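Note on the collate step above: pad_sequences_to_tensors right-pads the variable-length per-sample fields so they can be stacked into batch tensors. A minimal stand-in sketch of that behavior (pad_batch and pad_token_id are illustrative names; the actual arealite helper may pad additional keys or build attention masks):

import torch


def pad_batch(samples, pad_token_id=0):
    # Each sample carries variable-length "input_ids" and "loss_mask" lists.
    max_len = max(len(s["input_ids"]) for s in samples)
    input_ids, loss_mask = [], []
    for s in samples:
        pad = max_len - len(s["input_ids"])
        input_ids.append(list(s["input_ids"]) + [pad_token_id] * pad)
        loss_mask.append(list(s["loss_mask"]) + [0] * pad)  # padding never gets loss
    return {
        "input_ids": torch.tensor(input_ids, dtype=torch.long),
        "loss_mask": torch.tensor(loss_mask, dtype=torch.long),
    }
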
View File

@ -0,0 +1,139 @@
experiment_name: clevr_count_70k-grpo
trial_name: trial1
seed: 1
total_train_epochs: 3
tokenizer_path: ${actor.path}
async_training: true

cluster:
  n_nodes: 1
  n_gpus_per_node: 8
  cluster_name: na132
  fileroot: /storage/openpsi/experiments
  name_resolve:
    type: nfs
    nfs_record_root: /storage/openpsi/experiments/name_resolve/clevr_count_70k-grpo
    etcd3_addr: etcd-client.openpsi-etcd.svc.sigma-na130-lingbo.na130.wl-robby.local:2379

allocation_mode: sglang.d1p1t1+d7p1t1

rollout:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  max_concurrent_rollouts: 256
  queue_size: null
  consumer_batch_size: ${train_dataset.batch_size}
  max_head_offpolicyness: 4
  enable_rollout_tracing: false

gconfig:
  n_samples: 4
  min_new_tokens: 0
  max_new_tokens: 512
  greedy: false
  temperature: 1.0

actor:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  path: /storage/openpsi/models/Qwen2.5-VL-3B-Instruct
  init_from_scratch: false
  disable_dropout: true
  gradient_checkpointing: false
  dtype: bfloat16
  mb_spec:
    max_tokens_per_mb: 10240
  optimizer:
    type: adam
    lr: 2e-6
    weight_decay: 0.01
    beta1: 0.9
    beta2: 0.999
    eps: 1e-8
    lr_scheduler_type: constant
    gradient_clipping: 1.0
    warmup_steps_proportion: 0.001
  backend: fsdp
  group_size: ${gconfig.n_samples}
  group_adv_norm: false
  eps_clip: 0.4
  temperature: ${gconfig.temperature}
  reward_scaling: 10.0
  reward_bias: -0.5
  kl_ctl: 0.0
  ppo_n_minibatches: 1
  recompute_logprob: true
  use_decoupled_loss: true
  behav_imp_weight_cap: 5.0

ref:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  path: ${actor.path}
  init_from_scratch: false
  dtype: ${actor.dtype}
  mb_spec:
    max_tokens_per_mb: 10240
  optimizer: null
  backend: fsdp

# SGLang
server_only: false
sglang:
  model_path: ${actor.path}
  random_seed: ${seed}
  skip_tokenizer_init: true
  dtype: ${actor.dtype}
  max_running_requests: null
  context_length: 32768
  mem_fraction_static: 0.8

# datasets
train_dataset:
  batch_size: 32
  shuffle: true
  pin_memory: true
  num_workers: 4
  path: /storage/openpsi/data/clevr_count_70k/
  type: rl

valid_dataset:
  batch_size: 32
  shuffle: true
  pin_memory: true
  num_workers: 4
  path: /storage/openpsi/data/clevr_count_70k/
  type: rl

# Utilities
saver:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  freq_epochs: 1
  freq_steps: null
  freq_secs: null

checkpointer:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  freq_epochs: 1
  freq_steps: null
  freq_secs: 3600

evaluator:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  freq_epochs: 1
  freq_steps: null
  freq_secs: null

stats_logger:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}

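Note on the advantage settings above: group_size: ${gconfig.n_samples} together with group_adv_norm: false corresponds to GRPO-style group-relative advantages that subtract the per-prompt group mean without dividing by the group standard deviation. A minimal sketch of the idea, not the FSDPPPOActor implementation (function and variable names are illustrative):

import torch


def group_relative_advantages(rewards, group_size, normalize_std=False):
    # rewards: flat tensor laid out as consecutive groups of `group_size`
    # responses to the same prompt.
    grouped = rewards.view(-1, group_size)
    adv = grouped - grouped.mean(dim=1, keepdim=True)  # subtract the group mean
    if normalize_std:  # group_adv_norm: true would also divide by the group std
        adv = adv / (grouped.std(dim=1, keepdim=True) + 1e-6)
    return adv.view(-1)


# With n_samples = 4 responses per prompt:
# group_relative_advantages(torch.tensor([1.0, 0.0, 0.0, 1.0]), 4)
# -> tensor([ 0.5, -0.5, -0.5,  0.5])
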
View File

@ -0,0 +1,81 @@
experiment_name: clevr_count_70k-sft
trial_name: trial0

cluster:
  n_nodes: 1
  n_gpus_per_node: 1
  cluster_name: na132
  fileroot: /storage/openpsi/experiments
  name_resolve:
    type: nfs
    nfs_record_root: /storage/openpsi/experiments/name_resolve/clevr_count_70k-sft
    etcd3_addr: etcd-client.openpsi-etcd.svc.sigma-na130-lingbo.na130.wl-robby.local:2379

seed: 1
total_train_epochs: 3
tokenizer_path: ${model.path}

model:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  path: /storage/openpsi/models/Qwen2-VL-7B
  dtype: bfloat16
  init_from_scratch: false
  gradient_checkpointing: false
  mb_spec:
    max_tokens_per_mb: 4096
  optimizer:
    type: adam
    lr: 2e-5
    weight_decay: 0.05
    beta1: 0.9
    beta2: 0.95
    eps: 1e-5
    lr_scheduler_type: cosine
    gradient_clipping: 1.0
  backend: fsdp

train_dataset:
  batch_size: 128
  shuffle: true
  pin_memory: true
  num_workers: 4
  path: /storage/openpsi/data/clevr_count_70k/
  type: sft

valid_dataset:
  batch_size: 128
  shuffle: true
  pin_memory: true
  num_workers: 4
  path: /storage/openpsi/data/clevr_count_70k/
  type: sft

# Utilities
saver:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  freq_epochs: 1
  freq_steps: null
  freq_secs: null

checkpointer:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  freq_epochs: 1
  freq_steps: null
  freq_secs: 3600

evaluator:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  freq_epochs: null
  freq_steps: 1
  freq_secs: null

stats_logger:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}

View File

@ -2,9 +2,9 @@ experiment_name: gsm8k-grpo
trial_name: trial0
allocation_mode: sglang.d4p1t1+d4p1t1
cluster:
  fileroot: /tmp/arealite/experiments
  n_nodes: 1
  n_gpus_per_node: 8
  fileroot: /tmp/arealite/experiments
  name_resolve:
    type: nfs
    nfs_record_root: /tmp/areal/name_resolve
@ -90,11 +90,17 @@ train_dataset:
  batch_size: 256
  shuffle: true
  pin_memory: true
  num_workers: 4
  path: openai/gsm8k
  type: rl
valid_dataset:
  batch_size: 256
  shuffle: true
  pin_memory: true
  num_workers: 4
  path: openai/gsm8k
  type: rl
# Utilities
saver:
@ -125,5 +131,4 @@ stats_logger:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  wandb:
    mode: disabled

View File

@ -1,8 +1,9 @@
experiment_name: gsm8k-sft
trial_name: trial0
n_nodes: 1
n_gpus_per_node: 8
cluster:
  n_nodes: 1
  n_gpus_per_node: 8
  name_resolve:
    type: nfs
    nfs_record_root: /tmp/areal/name_resolve
@ -34,11 +35,17 @@ train_dataset:
  batch_size: 128
  shuffle: true
  pin_memory: true
  num_workers: 4
  path: openai/gsm8k
  type: sft
valid_dataset:
  batch_size: 128
  shuffle: true
  pin_memory: true
  num_workers: 4
  path: openai/gsm8k
  type: sft
# Utilities
saver:
@ -69,5 +76,3 @@ stats_logger:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  wandb:
    mode: disabled

View File

@ -0,0 +1,125 @@
import math
from io import BytesIO
from typing import Any, Dict, Optional, Union

from datasets import load_dataset
from datasets.distributed import split_dataset_by_node
from PIL.Image import Image as ImageObject


def convert_image(
    image: Union[Dict[str, Any], ImageObject, str],
    max_pixels: Optional[int],
) -> bytes:
    # Downscale images whose total pixel count exceeds max_pixels, force RGB,
    # and return the JPEG-encoded bytes.
    if max_pixels is not None and (image.width * image.height) > max_pixels:
        resize_factor = math.sqrt(max_pixels / (image.width * image.height))
        width, height = int(image.width * resize_factor), int(
            image.height * resize_factor
        )
        image = image.resize((width, height))

    if image.mode != "RGB":
        image = image.convert("RGB")

    with BytesIO() as output:
        image.save(output, format="JPEG")
        return output.getvalue()


def get_clevr_count_70k_sft_dataset(path, split, processor, rank, world_size):
    """
    "clevr_count_70k": {
        "image_key": "images",
        "question_key": "problem",
        "answer_key": "answer"
    },
    """
    dataset = load_dataset(path=path, split=split)
    dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
    tokenizer = processor.tokenizer

    def process_example(example, idx):
        # Replace the <image> placeholder with the processor-specific image tokens
        # and build the full SFT sequence (problem + answer + EOS).
        images = example["images"]
        if "qwen" in processor.image_processor.image_processor_type.lower():
            image_token = "<|vision_start|><|image_pad|><|vision_end|>"
        else:
            image_token = processor.image_token if processor is not None else "<image>"
        example["problem"] = (
            example["problem"].replace("<image>", image_token).replace("different", "")
        )
        processed_images = []
        for image in images:
            processed_images.append(convert_image(image, 336 * 336))
        example["images"] = processed_images
        example["seq"] = example["problem"] + example["answer"] + tokenizer.eos_token
        return example

    dataset = dataset.map(
        lambda example, idx: process_example(example, idx),
        with_indices=True,
    )

    def _process(example):
        # Tokenize text and images together; only the answer tokens receive loss.
        text = example["seq"]
        processed_input = processor(
            text=[text],
            images=example["images"],
            padding=False,
            return_tensors="pt",
            return_length=True,
            return_attention_mask=False,
        )
        example["input_ids"] = processed_input["input_ids"].squeeze(0)
        example["pixel_values"] = processed_input["pixel_values"]
        example["image_grid_thw"] = processed_input["image_grid_thw"]
        answer_token = tokenizer.encode(example["answer"])
        loss_mask = [0] * (len(example["input_ids"]) - len(answer_token)) + [1] * len(
            answer_token
        )
        example["loss_mask"] = loss_mask
        return example

    dataset = dataset.map(
        lambda x: _process(x), remove_columns=["images", "seq", "problem", "answer"]
    )
    return dataset


def get_clevr_count_70k_rl_dataset(path, split, processor, rank, world_size):
    dataset = load_dataset(path=path, split=split)
    dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)

    def process(sample):
        processed_images = [
            convert_image(image, 336 * 336) for image in sample["images"]
        ]
        if "qwen" in processor.image_processor.image_processor_type.lower():
            image_token = "<|vision_start|><|image_pad|><|vision_end|>"
        else:
            image_token = processor.image_token if processor is not None else "<image>"
        system_prompt = {
            "role": "system",
            "content": (
                "Solve the following question: count the number of items in the image and provide the final answer in [ ] format, ensuring that only the number is inside the brackets without any additional text or explanations. "
            ),
        }
        messages = [
            {
                "role": "user",
                "content": sample["problem"]
                .replace("<image>", image_token)
                .replace("different", ""),
            }
        ]
        messages.insert(0, system_prompt)
        messages = processor.tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=False
        )
        return {"messages": messages, "images": processed_images}

    dataset = dataset.map(process).remove_columns(["problem"])
    return dataset

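Note on convert_image above: images whose total pixel count exceeds the 336*336 budget are downscaled before being re-encoded as JPEG bytes. A small usage sketch (the image is synthetic and the sizes are plain arithmetic from the budget):

from PIL import Image

img = Image.new("RGB", (1024, 768))          # 786,432 pixels > 336 * 336 = 112,896
jpeg_bytes = convert_image(img, 336 * 336)   # downscaled to about 387x290, then JPEG-encoded
print(type(jpeg_bytes))                      # <class 'bytes'>
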
View File

@ -0,0 +1,30 @@
from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node


def get_gsm8k_sft_dataset(path, split, tokenizer, rank, world_size):
    dataset = load_dataset(path=path, name="main", split=split)
    dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)

    def process(sample):
        seq_token = tokenizer.encode(
            sample["question"] + sample["answer"] + tokenizer.eos_token
        )
        prompt_token = tokenizer.encode(sample["question"])
        loss_mask = [0] * len(prompt_token) + [1] * (len(seq_token) - len(prompt_token))
        return {"input_ids": seq_token, "loss_mask": loss_mask}

    dataset = dataset.map(process).remove_columns(["question", "answer"])
    return dataset


def get_gsm8k_rl_dataset(path, split, rank, world_size):
    dataset = load_dataset(path=path, name="main", split=split)
    dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)

    def process(sample):
        messages = [{"role": "user", "content": sample["question"]}]
        return {"messages": messages}

    dataset = dataset.map(process).remove_columns(["question"])
    return dataset

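Note on the SFT loss-mask convention used by get_gsm8k_sft_dataset: prompt tokens are masked out (0) and the answer plus EOS tokens are kept (1), so only the answer contributes to the LM loss. A toy sketch with a whitespace "tokenizer" standing in for a real HF tokenizer:

class ToyTokenizer:
    eos_token = "<eos>"

    def encode(self, text):
        return text.split()  # whitespace split stands in for real tokenization


tok = ToyTokenizer()
question, answer = "Q: 2 + 2 = ", "A: 4 "
seq_token = tok.encode(question + answer + tok.eos_token)
prompt_token = tok.encode(question)
loss_mask = [0] * len(prompt_token) + [1] * (len(seq_token) - len(prompt_token))
print(loss_mask)  # [0, 0, 0, 0, 0, 1, 1, 1]
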
View File

@ -3,12 +3,11 @@ import sys
import torch
import torch.distributed as dist
from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node
from torchdata.stateful_dataloader import StatefulDataLoader

from arealite.api.cli_args import GRPOConfig, load_expr_config
from arealite.api.io_struct import AllocationMode, FinetuneSpec, WeightUpdateMeta
from arealite.dataset.__init__ import get_custom_dataset
from arealite.engine.ppo.actor import FSDPPPOActor
from arealite.engine.sglang_remote import RemoteSGLangEngine
from arealite.utils.device import log_gpu_stats
@ -22,21 +21,6 @@ from realhf.base import logging, seeding, stats_tracker

logger = logging.getLogger("GSM8K grpo")


def process_gsm8k_rl_dataset(dataset: Dataset):
    def process(sample):
        messages = [{"role": "user", "content": sample["question"]}]
        return {"messages": messages}

    dataset = dataset.map(process).remove_columns(["question"])
    return dataset


def get_gsm8k_dataset(split, rank, world_size):
    dataset = load_dataset(path="openai/gsm8k", name="main", split=split)
    dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
    return process_gsm8k_rl_dataset(dataset)


def gsm8k_reward_fn(prompt, completions, prompt_ids, completion_ids, answer, **kwargs):
    from realhf.impl.dataset.math_parser import process_results
@ -52,10 +36,26 @@ def main(args):
    tokenizer = load_hf_tokenizer(config.tokenizer_path)
    seeding.set_random_seed(config.seed, key=f"trainer{rank}")

    train_dataset = get_custom_dataset(
        path=config.train_dataset.path,
        rank=rank,
        world_size=world_size,
        split="train",
        type=config.train_dataset.type,
        tokenizer=tokenizer,
    )
    valid_dataset = get_custom_dataset(
        path=config.valid_dataset.path,
        rank=rank,
        world_size=world_size,
        split="test",
        type=config.valid_dataset.type,
        tokenizer=tokenizer,
    )
    # Create dataset and dataloaders
    train_dataloader = StatefulDataLoader(
        get_gsm8k_dataset("train", rank, world_size),
        train_dataset,
        batch_size=config.train_dataset.batch_size // world_size,
        shuffle=config.train_dataset.shuffle,
        num_workers=config.train_dataset.num_workers,
@ -63,7 +63,7 @@ def main(args):
        drop_last=config.train_dataset.drop_last,
    )
    valid_dataloader = StatefulDataLoader(
        get_gsm8k_dataset("test", rank, world_size),
        valid_dataset,
        batch_size=config.valid_dataset.batch_size // world_size,
        shuffle=config.valid_dataset.shuffle,
        num_workers=config.valid_dataset.num_workers,

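Note on get_custom_dataset: it replaces the per-script dataset helpers removed above, but its implementation is not part of this excerpt. A plausible sketch of the dispatch it performs over (path, type), reusing the builders defined in the dataset files shown earlier (the module import paths are assumptions, not shown in this diff):

def get_custom_dataset_sketch(
    path, rank, world_size, split, type, tokenizer=None, processor=None
):
    if "gsm8k" in path:
        # Assumed module name for the gsm8k builders shown above.
        from arealite.dataset.gsm8k import (
            get_gsm8k_rl_dataset,
            get_gsm8k_sft_dataset,
        )

        if type == "sft":
            return get_gsm8k_sft_dataset(path, split, tokenizer, rank, world_size)
        return get_gsm8k_rl_dataset(path, split, rank, world_size)
    if "clevr_count_70k" in path:
        # Assumed module name for the clevr_count_70k builders shown above.
        from arealite.dataset.clevr_count_70k import (
            get_clevr_count_70k_rl_dataset,
            get_clevr_count_70k_sft_dataset,
        )

        if type == "sft":
            return get_clevr_count_70k_sft_dataset(
                path, split, processor, rank, world_size
            )
        return get_clevr_count_70k_rl_dataset(path, split, processor, rank, world_size)
    raise ValueError(f"Unsupported dataset: {path} ({type})")
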
View File

@ -7,6 +7,7 @@ from torchdata.stateful_dataloader import StatefulDataLoader
from arealite.api.cli_args import SFTConfig, load_expr_config
from arealite.api.io_struct import FinetuneSpec
from arealite.dataset.__init__ import get_custom_dataset
from arealite.engine.sft.lm_engine import FSDPLMEngine
from arealite.utils.data import pad_sequences_to_tensors
from arealite.utils.evaluator import Evaluator
@ -16,25 +17,6 @@ from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import stats_tracker


def process_gsm8k_sft_dataset(dataset: Dataset, tokenizer):
    def process(sample):
        seq_token = tokenizer.encode(
            sample["question"] + sample["answer"] + tokenizer.eos_token
        )
        prompt_token = tokenizer.encode(sample["question"])
        loss_mask = [0] * len(prompt_token) + [1] * (len(seq_token) - len(prompt_token))
        return {"input_ids": seq_token, "loss_mask": loss_mask}

    dataset = dataset.map(process).remove_columns(["question", "answer"])
    return dataset


def get_gsm8k_dataset(split, tokenizer, rank, world_size):
    dataset = load_dataset(path="openai/gsm8k", name="main", split=split)
    dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
    return process_gsm8k_sft_dataset(dataset, tokenizer)


def main(args):
    config, _ = load_expr_config(args, SFTConfig)
    config: SFTConfig
@ -43,9 +25,26 @@ def main(args):
    world_size = int(os.getenv("WORLD_SIZE"))
    tokenizer = load_hf_tokenizer(config.tokenizer_path)

    train_dataset = get_custom_dataset(
        path=config.train_dataset.path,
        rank=rank,
        world_size=world_size,
        split="train",
        type=config.train_dataset.type,
        tokenizer=tokenizer,
    )
    valid_dataset = get_custom_dataset(
        path=config.valid_dataset.path,
        rank=rank,
        world_size=world_size,
        split="test",
        type=config.valid_dataset.type,
        tokenizer=tokenizer,
    )
    # Create dataset and dataloaders
    train_dataloader = StatefulDataLoader(
        get_gsm8k_dataset("train", tokenizer, rank, world_size),
        train_dataset,
        batch_size=config.train_dataset.batch_size // world_size,
        shuffle=config.train_dataset.shuffle,
        num_workers=config.train_dataset.num_workers,
@ -53,7 +52,7 @@ def main(args):
        drop_last=config.train_dataset.drop_last,
    )
    valid_dataloader = StatefulDataLoader(
        get_gsm8k_dataset("test", tokenizer, rank, world_size),
        valid_dataset,
        batch_size=config.valid_dataset.batch_size // world_size,
        shuffle=config.valid_dataset.shuffle,
        num_workers=config.valid_dataset.num_workers,

View File

@ -31,7 +31,7 @@ dependencies = [
"huggingface_hub",
"datasets",
"accelerate",
"transformers==4.53.0",
"transformers==4.53.1",
# Scientific computing
"scipy",

View File

@ -69,6 +69,27 @@ def load_hf_tokenizer(
    return tokenizer


@lru_cache(maxsize=8)
def load_hf_processor_and_tokenizer(
    model_name_or_path: str,
    fast_tokenizer=True,
    padding_side: Optional[str] = None,
) -> Tuple[transformers.AutoProcessor, transformers.PreTrainedTokenizerFast]:
    """Load a tokenizer and processor from Hugging Face."""
    tokenizer = load_hf_tokenizer(model_name_or_path, fast_tokenizer, padding_side)
    try:
        processor = transformers.AutoProcessor.from_pretrained(
            model_name_or_path, trust_remote_code=True, force_download=True
        )
    except Exception:
        processor = None
        logger.warning(
            f"Failed to load processor for {model_name_or_path}. "
            "Using tokenizer only. This may cause issues with some models."
        )
    return processor, tokenizer


@pdclasses.dataclass
class SequenceSplitSpec:
    partitions: Optional[List[Tuple[int, int]]] = None

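Note on load_hf_processor_and_tokenizer: for VLMs such as Qwen2.5-VL the returned processor handles both images and text, while for text-only models it may be None and callers fall back to the tokenizer. An illustrative call (the model path is the one used in the configs above):

processor, tokenizer = load_hf_processor_and_tokenizer(
    "/storage/openpsi/models/Qwen2.5-VL-3B-Instruct"
)
if processor is None:
    # Text-only checkpoint: only the tokenizer is usable.
    input_ids = tokenizer.encode("How many cubes are in the image?")
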
View File

@ -35,6 +35,7 @@ def load_hf_or_local_file(path: str) -> str:
    =>
    /root/.cache/huggingface/hub/models--inclusionAI--AReaL-RL-Data/data/boba_106k_0319.jsonl
    """
    path = str(path)
    if path.startswith("hf://") or path.startswith("hf-dataset://"):
        repo_type = "dataset" if path.startswith("hf-dataset://") else "model"
        hf_path = path.strip().split("://")[1]