mirror of https://github.com/inclusionAI/AReaL
[WIP][feat] Initial support for VLMs, add Qwen2VL SFT test and Qwen2.5VL GRPO test (#188)
* vlm_sft_test * vlm_sft_test * . * . * Fix unresolved issue in SFTTrainer PR (#139) * . * . * efficient loading * format * . * . * Fix unresolved issue in SFTTrainer PR (#139) * . * . * efficient loading * format * . * . * image_process0701 * image_process0701 * image_process0701_2 * image_process0701_2 * image_process0701_3 * image_process0701_3 * . * . * . * . * . * . * imageprocess0702 * imageprocess0702 * image_process0702_2 * image_process0702_2 * image_process0702_3 * image_process0702_3 * image_process0702_4 * image_process0702_4 * image_process0702_5 * image_process0702_5 * image_process0703_1 * image_process0703_1 * 0703_2 * 0703_2 * 0703_3 * 0703_3 * 0703_4 * 0703_4 * 0703_4 * 0703_4 * 0703_5 * 0703_5 * 0703_6 * 0703_6 * 0703_7 * 0703_7 * 0703_8 * 0703_8 * 0703_9 * 0703_9 * 0703_11 * 0703_11 * 0703_12 * 0703_12 * 0703_13 * 0703_13 * 0703_14 * 0703_14 * 0703_15 * 0703_15 * 0703_16 * 0703_16 * 0703-17 * 0703-17 * 0703_18 * 0703_18 * 0703_18 * 0703_18 * 0703_19 * 0703_19 * 0704_1 * 0704_1 * 0704_2 * 0704_2 * 0704_3 * 0704_3 * . * . * 0707_1 * 0707_1 * 0707_2 * 0707_2 * 0703_3 * 0703_3 * r * p * fix * fix * refactor * 0707_6 * 0707_7 * refactor1 * 0707_undone * 0708_1 * 0708_2 * 0708_3 * 0708_7 * 0708_4 * 0709_1 * 0709_2 * 0709_3 * 0709_4 * 0709_5 * 0709_ * 0709_6 * 0709_7 * 0709_7 * 0709_8 * 0709_9 * 0710_1 * 0710_2 * 0710_2 * 0710_3 * 0710_3 * 0710_3 * 0710_5 * 0710_4 * merge_2 * merge_3 * 0711_1 * 0711_2 * 0711_3 * 0711_4 * 0711_6 * 0711_7 * 0711_8 * 0711_8 * 0711_9 * 0711_10 * 0711-11 * PullRequest: 353 [Lite] Add gradient checkpointing to FSDPEngine Merge branch mzy/add-gradient-ckpt of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/353 Reviewed-by: 博惟 <bowei.fw@antgroup.com> * add gradient checkpointing * PullRequest: 354 [lite] GRPO pre-commit: minor changes in FSDP engine Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/354 Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com> * . * . * . * . * 0714_1 * 0714_2 * 0714_3 * 0714_3 * 0714_5 * PullRequest: 355 [Lite] GRPO pre-commit 2: Refactor RemoteSGLangEngine thread and SGLang configuration Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/355?tab=commit Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com> * . * . * . * . * . * . * fix * . * PullRequest: 357 [lite] GRPO pre-commit 3: Fix typos and experiment utilities Merge branch fw/lite-fix2 of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/357?tab=comment Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com> * . * . * . * . * . * fix destroy process group * PullRequest: 358 [lite] Support GRPO training locally with the GSM8k dataset Merge branch fw/lite-fix3 of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/358 Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com> * . * . * . * . * fix loss mask * fix * . * 0715_1 * 0715_2 * 0715_2 * 0716_1 * 0716_2 * PullRequest: 368 [lite] Refactor train engine after merging contributions from GitHub Merge branch fw/lite-train-engine of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/368 Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com> * . * . 
* PullRequest: 371 [lite] [fix] fix misc bugs in GRPO implementation Merge branch fw/lite-fix0716 of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/371 Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com> * . * 0716_3 * 0716_4 * 0716_4 * 0716_5 * 0717_1 * 0717_3 * 0717_3 * 0717_4 * 0717_5 * 0717_6 * 0717_6 * 0717_6 * 0718_2 * 0718_4 * 0718_5 * PullRequest: 370 [lite] Add Slurm Launcher and Ray Launcher Merge branch mzy/lite/launcher of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/370 Reviewed-by: 博惟 <bowei.fw@antgroup.com> * . * . * . * fix * merge_0721 * 0721_1 * PullRequest: 392 [lite] Fix several bugs regarding RL learning and add an example to reproduce boba-math results. Merge branch fw/lite-boba of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/392 Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com> * support fsdp engine and sglang remote engine * minor fix * . * refactor trainer * add close * rm mb_spec * . * fix * . * qwen2 grpo works * fix * fix * async works * fix * slurm launcher not tested * fix arg parse * . * sglang server wrapper * . * . * slurm run * ready for boba * debug * 32k run * . * . * fix * . * . * . * . * . * fix * . * fix * . * . * . * . * fix * . * . * . * . * . * . * . * refactor train engine * refactor train engine * . * fix update weight error * . * . * match train * format * . * fix * seems to work * . * . * . * . * 0721_2 * 0721_3 * 0721_4 * . * 0721_formal * 0721_formal * 0721_merge4 * 0721_merge5 * 0721_6 * 0721_merge6 * 0721_merge7 * 0721_8 * 0722_1 * 0722_2 * 0722_3 * 0722_4 * 0722_4 * 0722_5 * 0722_6 * 0722_7 * 0723_1 * reformatted * clang-reformatted * clang-reformatted2 * 0723_1 * 0723_1 * 0723_1 * 0723_merge3 * 0723_4 * 0723_reformatted_5 * 0724_1 * 0724_1 * 0724_merge1 * 0724_merge2 * 0724_merge3 * 0724_merge3 * 0724_merge4 * 0724_merge5 * 0724_merge6 * 0724_merge7 * 0724_4 * 0724-merge8 * 0724_merge8 * 0725_1 * 0725_6 * 0725_7 * 0725_4padded_image * 0725_9padded_image * 0725_10padded_image * 0725 * 0725_12 * 0725_format --------- Co-authored-by: bowei.fw <bowei.fw@antgroup.com> Co-authored-by: nuzant <meizhiyu.mzy@antgroup.com> Co-authored-by: 朱晗 <lichangye.lcy@antgroup.com>
This commit is contained in:
parent e2a3579733
commit 7fb6a80e48
@@ -611,8 +611,15 @@ class ClusterSpecConfig:

@dataclass
class DatasetConfig:
    path: str = field(
        default=MISSING,
        metadata={
            "help": "Path to the dataset. Can be a local path or a HuggingFace dataset name."
        },
    )
    type: Optional[str] = field(
        default=None, metadata={"help": "Type of implemented dataset"}
        default=None,
        metadata={"help": "Type of training method.e.g., 'sft', 'rl', etc."},
    )
    batch_size: int = field(
        default=1, metadata={"help": "Batch size of the dataloader"}
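For reference, the new `DatasetConfig` above is consumed as `train_dataset`/`valid_dataset` in `BaseExperimentConfig` (next hunk). A minimal construction sketch with placeholder values, not taken from this PR, assuming the dataclass lives in `arealite/api/cli_args.py` as the imports elsewhere in this diff suggest:

```python
# Hypothetical values for illustration only; field names follow the dataclass above.
from arealite.api.cli_args import DatasetConfig

train_dataset = DatasetConfig(
    path="/path/to/clevr_count_70k",  # local path or HuggingFace dataset name
    type="rl",                        # training method, e.g. "sft" or "rl"
    batch_size=32,
)
```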
@@ -743,7 +750,7 @@ class BaseExperimentConfig:
    tokenizer_path: str = field(default="")

    train_dataset: DatasetConfig = field(default_factory=DatasetConfig)
    valid_dataset: DatasetConfig = field(default_factory=DatasetConfig)
    valid_dataset: Optional[DatasetConfig] = field(default=None)

    saver: SaverConfig = field(default_factory=SaverConfig)
    checkpointer: SaverConfig = field(default_factory=SaverConfig)
@@ -8,7 +8,10 @@ import uuid
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple

from transformers import PreTrainedTokenizerFast
import torch
from gymnasium.core import ActType, ObsType
from PIL.Image import Image as ImageObject
from transformers import AutoProcessor, PreTrainedTokenizerFast

from arealite.api.cli_args import GenerationHyperparameters, SaverConfig
from arealite.utils.network import find_free_ports, gethostip
@@ -51,6 +54,16 @@ class LLMResponse:
        return len(self.output_tokens)


@dataclass
class VLMRequest(LLMRequest):
    image_data: Optional[List[ImageObject | str]] = field(default_factory=list)


@dataclass
class VLMResponse(LLMResponse):
    input_images: List[ImageObject | str] = field(default_factory=list)


@dataclass
class FinetuneSpec:
    total_train_epochs: int

@@ -216,7 +229,8 @@ class SaveLoadMeta:
    path: str
    weight_format: str
    with_optim: bool
    tokenizer: Optional[PreTrainedTokenizerFast]
    tokenizer: PreTrainedTokenizerFast | None
    processor: AutoProcessor | None
    base_model_path: str | None
    naive_distributed: bool = False
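The `VLMRequest`/`VLMResponse` structs extend the text-only request/response with image payloads. As a minimal sketch (assuming the inherited `LLMRequest` fields `rid`, `input_ids`, and `gconfig` shown later in this diff), a request carrying one base64-encoded image could be built as:

```python
# Sketch only: the token ids and image string are placeholders, not values from this PR.
import uuid

from arealite.api.cli_args import GenerationHyperparameters
from arealite.api.io_struct import VLMRequest

req = VLMRequest(
    rid=uuid.uuid4().hex,                  # requests sharing a rid reuse server KV cache
    input_ids=[151644, 872, 198],          # pre-tokenized prompt (placeholder ids)
    image_data=["<base64-encoded PNG>"],   # base64 strings or PIL.Image objects
    gconfig=GenerationHyperparameters(n_samples=1),
)
```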
@@ -0,0 +1,47 @@
from typing import Optional

import transformers

VALID_DATASETS = ["gsm8k", "clevr_count_70k"]


def get_custom_dataset(
    path: str,
    rank: int,
    world_size: int,
    type: str = "sft",
    split: Optional[str] = None,
    tokenizer: Optional[transformers.PreTrainedTokenizerFast] = None,
    processor: Optional[transformers.AutoProcessor] = None,
    **kwargs,
):

    if "gsm8k" in path and type == "sft":
        from examples.arealite.dataset.gsm8k import get_gsm8k_sft_dataset

        return get_gsm8k_sft_dataset(path, split, tokenizer, rank, world_size, **kwargs)
    elif "gsm8k" in path and type == "rl":
        from examples.arealite.dataset.gsm8k import get_gsm8k_rl_dataset

        return get_gsm8k_rl_dataset(path, split, rank, world_size, **kwargs)
    elif "clevr_count_70k" in path and type == "sft":
        from examples.arealite.dataset.clevr_count_70k import (
            get_clevr_count_70k_sft_dataset,
        )

        return get_clevr_count_70k_sft_dataset(
            path, split, processor, rank, world_size, **kwargs
        )
    elif "clevr_count_70k" in path and type == "rl":
        from examples.arealite.dataset.clevr_count_70k import (
            get_clevr_count_70k_rl_dataset,
        )

        return get_clevr_count_70k_rl_dataset(
            path, split, processor, rank, world_size, **kwargs
        )
    else:
        raise ValueError(
            f"Dataset {path} with split {split} and training type {type} is not supported. "
            f"Supported datasets are: {VALID_DATASETS}. "
        )
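A hedged usage sketch of the dispatcher above; the dataset path and processor checkpoint are placeholders, and the import path mirrors the one used by the GRPO example script later in this diff:

```python
# Illustration only; replace the path and checkpoint with real ones.
from transformers import AutoProcessor

from arealite.dataset.__init__ import get_custom_dataset

processor = AutoProcessor.from_pretrained("/path/to/Qwen2.5-VL-checkpoint")
train_ds = get_custom_dataset(
    path="/data/clevr_count_70k",  # matched by the "clevr_count_70k" substring
    rank=0,
    world_size=1,
    type="rl",
    split="train",
    processor=processor,
)
```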
@@ -9,6 +9,8 @@ from tensordict import TensorDict
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForImageTextToText,
    AutoProcessor,
    PretrainedConfig,
    PreTrainedTokenizerFast,
    get_constant_schedule_with_warmup,

@@ -29,8 +31,8 @@ from arealite.utils.data import (
    unsqueeze_mb_list,
)
from arealite.utils.fsdp import get_cosine_schedule_with_warmup
from arealite.utils.model import disable_dropout_in_model
from realhf.api.core.data_api import load_hf_tokenizer
from arealite.utils.model import VALID_VISION_MODELS, disable_dropout_in_model
from realhf.api.core.data_api import load_hf_processor_and_tokenizer, load_hf_tokenizer
from realhf.base import constants, logging

logger = logging.getLogger("Base HF Engine")

@@ -44,6 +46,7 @@ class BaseHFEngine(TrainEngine):
        self.model: torch.nn.Module
        self.optimizer: torch.optim.Optimizer
        self.tokenizer: PreTrainedTokenizerFast
        self.processor: AutoProcessor | None = None
        # huggingface model config
        self.model_config: PretrainedConfig
        self._version: int = 0

@@ -54,6 +57,12 @@ class BaseHFEngine(TrainEngine):
        self._parallelism_group: dist.ProcessGroup
        self.weight_update_group_initialized = False

        self.model_config = AutoConfig.from_pretrained(
            pretrained_model_name_or_path=self.config.path,
            trust_remote_code=True,
        )
        self.is_vision_model = self.model_config.model_type in VALID_VISION_MODELS

        self.world_size = int(os.environ["WORLD_SIZE"])

    def set_version(self, version: int):
@@ -92,32 +101,54 @@
        self.device = torch.device(int(os.environ["LOCAL_RANK"]))

        dtype = getattr(torch, self.config.dtype)
        self.model_config = AutoConfig.from_pretrained(
            pretrained_model_name_or_path=self.config.path,
            trust_remote_code=True,
        )
        self.tokenizer = load_hf_tokenizer(self.config.path)
        tik = time.perf_counter()
        with torch.device("cuda"):
            if self.config.init_from_scratch:
                # initialize scratch model from config
                # NOTE: VLM cannot directly load state dict using this
                # random initialized model, so otherwise we call
                # from_pretrained rather than loading weights into this random model.
                model = AutoModelForCausalLM.from_config(
                    self.model_config,
                    torch_dtype=dtype,
                    attn_implementation=self.config.attn_impl,
        if self.is_vision_model:
            if dtype == torch.float16:
                raise ValueError(
                    "Vision models do not support float16 dtype. Please use bfloat16."
                )
            else:
                model = AutoModelForCausalLM.from_pretrained(
            if self.config.init_from_scratch:
                raise ValueError(
                    "Vision models do not support initialization from scratch. Please use a pretrained model."
                )
            self.processor, self.tokenizer = load_hf_processor_and_tokenizer(
                self.config.path
            )

            tik = time.perf_counter()
            with torch.device("cuda"):
                model = AutoModelForImageTextToText.from_pretrained(
                    pretrained_model_name_or_path=self.config.path,
                    trust_remote_code=True,
                    torch_dtype=dtype,
                    attn_implementation=self.config.attn_impl,
                )
            if self.config.disable_dropout:
                disable_dropout_in_model(model)
            if self.config.disable_dropout:
                disable_dropout_in_model(model)
        else:
            self.tokenizer = load_hf_tokenizer(self.config.path)
            tik = time.perf_counter()
            with torch.device("cuda"):
                if self.config.init_from_scratch:
                    # initialize scratch model from config
                    # NOTE: VLM cannot directly load state dict using this
                    # random initialized model, so otherwise we call
                    # from_pretrained rather than loading weights into this random model.
                    model = AutoModelForCausalLM.from_config(
                        self.model_config,
                        torch_dtype=dtype,
                        attn_implementation=self.config.attn_impl,
                    )
                else:
                    model = AutoModelForCausalLM.from_pretrained(
                        pretrained_model_name_or_path=self.config.path,
                        trust_remote_code=True,
                        torch_dtype=dtype,
                        attn_implementation=self.config.attn_impl,
                    )
            if self.config.disable_dropout:
                disable_dropout_in_model(model)

        if self.config.gradient_checkpointing:
            model.gradient_checkpointing_enable(
                gradient_checkpointing_kwargs={"use_reentrant": False}
@ -218,9 +249,15 @@ class BaseHFEngine(TrainEngine):
|
|||
|
||||
def prepare_mb_list(self, input_: TensorDict) -> MicroBatchList:
|
||||
assert "attention_mask" in input_ and "input_ids" in input_
|
||||
if self.is_vision_model:
|
||||
assert (
|
||||
"pixel_values" in input_ and "image_grid_thw" in input_
|
||||
), "For vision-language models, pixel_values and image_grid_thw must be present in input_"
|
||||
|
||||
if isinstance(input_, dict):
|
||||
input_ = TensorDict(input_, batch_size=[input_["input_ids"].shape[0]])
|
||||
input_ = amend_position_ids(input_)
|
||||
|
||||
mb_list = split_padded_tensor_dict_into_mb_list(input_, self.config.mb_spec)
|
||||
logger.info(
|
||||
f"Microbatch #tokens (rank {dist.get_rank()}): {mb_list.group_lens}"
|
||||
|
@ -230,6 +267,7 @@ class BaseHFEngine(TrainEngine):
|
|||
# NOTE: We unsqueeze here because huggingface transformer models requires
|
||||
# packed input to be of shape [1, total_seqlen].
|
||||
mb_list = unsqueeze_mb_list(mb_list)
|
||||
|
||||
# FIXME: the resulting max_seqlen is a tensor rather than an integer
|
||||
for mb in mb_list.mbs:
|
||||
mb["max_seqlen"] = int(mb["max_seqlen"])
|
||||
|
@ -237,6 +275,7 @@ class BaseHFEngine(TrainEngine):
|
|||
for mb in mb_list.padded_mbs:
|
||||
mb["max_seqlen"] = int(mb["max_seqlen"])
|
||||
mb["use_cache"] = False
|
||||
|
||||
return mb_list
|
||||
|
||||
def train_batch(
|
||||
|
@ -264,11 +303,13 @@ class BaseHFEngine(TrainEngine):
|
|||
for i, (pad_length, padded_mb_input, mb_input) in enumerate(
|
||||
zip(mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs)
|
||||
):
|
||||
|
||||
outputs = self.model(**padded_mb_input)
|
||||
|
||||
logits = outputs.logits.squeeze(0)
|
||||
logits = logits[:-pad_length] if pad_length > 0 else logits
|
||||
loss = loss_fn(logits, mb_input)
|
||||
|
||||
loss_scale = loss_weight_fn(mb_input) / total_loss_weight
|
||||
|
||||
# Scale loss for accumulation
|
||||
|
|
|
@ -11,7 +11,7 @@ from torch.distributed.checkpoint.state_dict import (
|
|||
StateDictOptions,
|
||||
get_model_state_dict,
|
||||
)
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
from transformers import AutoProcessor, PreTrainedTokenizerFast
|
||||
|
||||
from arealite.api.cli_args import TrainEngineConfig
|
||||
from arealite.api.engine_api import FinetuneSpec
|
||||
|
@ -27,6 +27,7 @@ from arealite.utils.fsdp import (
|
|||
fsdp2_load_full_state_dict,
|
||||
)
|
||||
from arealite.utils.save_load import get_state_dict_from_repo_id_or_path
|
||||
from realhf.api.core.data_api import load_hf_processor_and_tokenizer
|
||||
from realhf.base import logging, name_resolve, names, pkg_version
|
||||
|
||||
logger = logging.getLogger("FSDPEngine")
|
||||
|
@ -77,7 +78,7 @@ class FSDPEngine(BaseHFEngine):
|
|||
|
||||
def save(self, meta: SaveLoadMeta):
|
||||
if meta.weight_format == "hf":
|
||||
self._save_model_to_hf(meta.path, meta.tokenizer)
|
||||
self._save_model_to_hf(meta.path, meta.tokenizer, meta.processor)
|
||||
elif meta.weight_format == "dcp":
|
||||
# TODO: implement DCP save/load for FSDP
|
||||
raise NotImplementedError("DCP format saving is not implemented yet. ")
|
||||
|
@ -100,7 +101,10 @@ class FSDPEngine(BaseHFEngine):
|
|||
self.load_optimizer_state(meta.path)
|
||||
|
||||
def _save_model_to_hf(
|
||||
self, path: str, tokenizer: Optional[PreTrainedTokenizerFast]
|
||||
self,
|
||||
path: str,
|
||||
tokenizer: Optional[PreTrainedTokenizerFast],
|
||||
processor: Optional[AutoProcessor],
|
||||
):
|
||||
"""Save model in HuggingFace format."""
|
||||
if self.model is None:
|
||||
|
@ -119,6 +123,8 @@ class FSDPEngine(BaseHFEngine):
|
|||
self.model_config.save_pretrained(path)
|
||||
if tokenizer is not None:
|
||||
tokenizer.save_pretrained(path)
|
||||
if processor is not None:
|
||||
processor.save_pretrained(path)
|
||||
|
||||
dist.barrier(device_ids=[self.device.index])
|
||||
|
||||
|
@ -144,13 +150,13 @@ class FSDPEngine(BaseHFEngine):
|
|||
dist.barrier(device_ids=[self.device.index])
|
||||
torch.cuda.synchronize()
|
||||
elif meta.type == "disk":
|
||||
self._save_model_to_hf(meta.path, self.tokenizer)
|
||||
self._save_model_to_hf(meta.path, self.tokenizer, self.processor)
|
||||
# dist.barrier() are called when _save_model_to_hf finished
|
||||
if dist.get_rank() == 0:
|
||||
update_name = names.update_weights_from_disk(
|
||||
self.config.experiment_name,
|
||||
self.config.trial_name,
|
||||
self.model_version,
|
||||
self.get_version(),
|
||||
)
|
||||
name_resolve.add(
|
||||
update_name, str(datetime.now().timestamp()), keepalive_ttl=120
|
||||
|
@ -247,9 +253,11 @@ class FSDPEngine(BaseHFEngine):
|
|||
loss.backward()
|
||||
|
||||
# NOTE: grad norm clip function is different
|
||||
|
||||
grad_norm = fsdp2_clip_grad_norm_(
|
||||
self.model.parameters(), max_norm=self.optimizer_config.gradient_clipping
|
||||
)
|
||||
|
||||
if not torch.isfinite(grad_norm):
|
||||
self.optimizer.zero_grad()
|
||||
update_successful = False
|
||||
|
|
|
@@ -52,7 +52,6 @@ def compute_packed_sft_loss(logits: torch.Tensor, input_: TensorDict) -> torch.T
    logprobs = torch.where(loss_mask, logprobs, 0)

    loss = -logprobs.sum() / loss_mask.count_nonzero()

    with torch.no_grad():
        seqlogp = torch.zeros(
            cu_seqlens.shape[0] - 1, device=logits.device, dtype=torch.float64
@@ -24,6 +24,8 @@ from arealite.api.io_struct import (
    LLMRequest,
    LLMResponse,
    RolloutStat,
    VLMRequest,
    VLMResponse,
    WeightUpdateMeta,
)
from arealite.utils.data import concat_padded_tensors

@@ -219,7 +221,9 @@ class RemoteSGLangEngine(InferenceEngine):
            return server
        raise NotImplementedError("Only round-robin scheduling is implemented.")

    async def agenerate(self, req: LLMRequest) -> LLMResponse:
    async def agenerate(
        self, req: LLMRequest | VLMRequest
    ) -> LLMResponse | VLMResponse:
        """Async version of generate using aiohttp."""
        # Prepare request payload
        gconfig = req.gconfig
@@ -237,14 +241,23 @@ class RemoteSGLangEngine(InferenceEngine):
            "temperature": 0.0 if gconfig.greedy else gconfig.temperature,
            "stop_token_ids": stop_token_ids,
        }

        # NOTE: rid should NOT be passed in payload
        payload = {
            "input_ids": req.input_ids.copy(),
            "sampling_params": sample_params,
            "return_logprob": True,
            "stream": False,
        }
        if isinstance(req, VLMRequest):
            # VLMRequest has image_data
            payload = {
                "input_ids": req.input_ids.copy(),
                "image_data": req.image_data,  # ImageObject or str
                "sampling_params": sample_params,
                "return_logprob": True,
                "stream": False,
            }
        else:
            # NOTE: rid should NOT be passed in payload
            payload = {
                "input_ids": req.input_ids.copy(),
                "sampling_params": sample_params,
                "return_logprob": True,
                "stream": False,
            }

        # Make request
        start_time = time.perf_counter()
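Callers do not need to distinguish the two request types; the engine branches on `isinstance(req, VLMRequest)` internally and returns the matching response class. A brief calling sketch, assuming `engine` is a `RemoteSGLangEngine` and `req` is the request built in a rollout workflow as shown elsewhere in this diff:

```python
# Sketch: fan out n_samples generations for one (possibly vision) request.
import asyncio

resps = await asyncio.gather(*[engine.agenerate(req) for _ in range(n_samples)])
for resp in resps:
    assert len(resp.output_tokens) == len(resp.output_logprobs)
```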
@ -313,15 +326,28 @@ class RemoteSGLangEngine(InferenceEngine):
|
|||
|
||||
latency = time.perf_counter() - start_time
|
||||
|
||||
return LLMResponse(
|
||||
input_tokens=req.input_ids,
|
||||
output_tokens=accumulated_output_tokens,
|
||||
output_logprobs=accumulated_output_logprobs,
|
||||
output_versions=accumulated_versions,
|
||||
stop_reason=stop_reason,
|
||||
latency=latency,
|
||||
ttft=latency, # Simplified for non-streaming
|
||||
)
|
||||
if isinstance(req, VLMRequest):
|
||||
response = VLMResponse(
|
||||
input_tokens=req.input_ids,
|
||||
input_images=req.image_data,
|
||||
output_tokens=accumulated_output_tokens,
|
||||
output_logprobs=accumulated_output_logprobs,
|
||||
output_versions=accumulated_versions,
|
||||
stop_reason=stop_reason,
|
||||
latency=latency,
|
||||
ttft=latency, # Simplified for non-streaming
|
||||
)
|
||||
else:
|
||||
response = LLMResponse(
|
||||
input_tokens=req.input_ids,
|
||||
output_tokens=accumulated_output_tokens,
|
||||
output_logprobs=accumulated_output_logprobs,
|
||||
output_versions=accumulated_versions,
|
||||
stop_reason=stop_reason,
|
||||
latency=latency,
|
||||
ttft=latency, # Simplified for non-streaming
|
||||
)
|
||||
return response
|
||||
|
||||
def update_weights(self, meta: WeightUpdateMeta):
|
||||
for addr in self.addresses:
|
||||
|
@ -456,6 +482,7 @@ class RemoteSGLangEngine(InferenceEngine):
|
|||
):
|
||||
try:
|
||||
data = next(self.data_generator)
|
||||
|
||||
except StopIteration:
|
||||
self.data_generator = iter(dataloader)
|
||||
data = next(self.data_generator)
|
||||
|
@ -555,6 +582,7 @@ def update_weights_from_distributed(
|
|||
for addr in addresses
|
||||
]
|
||||
)
|
||||
|
||||
logger.info(f"Distributed update weights done in {time.perf_counter() - tik}s")
|
||||
|
||||
return uvloop.run(_fn())
|
||||
|
|
|
@@ -17,7 +17,6 @@ from realhf.base import gpu_utils, logging, name_resolve, names
from realhf.scheduler.client import JobException, JobInfo, JobState

logger = logging.getLogger("Local Scheduler")

JOB_STATE_TO_PROCESS_STATUS = {
    JobState.NOT_FOUND: [],
    JobState.PENDING: [psutil.STATUS_PARKED],
@ -63,23 +63,35 @@ def pad_sequences_to_tensors(
|
|||
) -> TensorDict:
|
||||
if not sequence_list:
|
||||
return TensorDict()
|
||||
max_length = max(len(seq) for item in sequence_list for seq in item.values())
|
||||
skip_keys = {"pixel_values", "image_grid_thw"}
|
||||
max_length = max(
|
||||
len(seq)
|
||||
for item in sequence_list
|
||||
for key, seq in item.items()
|
||||
if key not in skip_keys
|
||||
)
|
||||
result = {}
|
||||
for key in sequence_list[0].keys():
|
||||
padded = []
|
||||
if key in skip_keys:
|
||||
result[key] = [sequence_list[i][key] for i in range(len(sequence_list))]
|
||||
continue
|
||||
for item in sequence_list:
|
||||
x = item[key]
|
||||
if not torch.is_tensor(x):
|
||||
x = torch.tensor(x)
|
||||
padded.append(
|
||||
torch.nn.functional.pad(
|
||||
x, (0, max_length - len(item[key])), value=pad_value
|
||||
)
|
||||
padded_x = torch.nn.functional.pad(
|
||||
x, (0, max_length - len(item[key])), value=pad_value
|
||||
)
|
||||
padded.append(padded_x)
|
||||
result[key] = torch.stack(padded)
|
||||
attention_mask = [
|
||||
[1] * len(next(iter(item.values())))
|
||||
+ [0] * (max_length - len(next(iter(item.values()))))
|
||||
[1] * len(next(iter(item[key] for key in item.keys() if key not in skip_keys)))
|
||||
+ [0]
|
||||
* (
|
||||
max_length
|
||||
- len(next(iter(item[key] for key in item.keys() if key not in skip_keys)))
|
||||
)
|
||||
for item in sequence_list
|
||||
]
|
||||
result["attention_mask"] = torch.tensor(attention_mask, dtype=torch.bool)
|
||||
|
@ -121,9 +133,11 @@ def concat_padded_tensors(
|
|||
assert all("attention_mask" in td for td in tensor_dicts)
|
||||
max_length = max([x["attention_mask"].shape[1] for x in tensor_dicts])
|
||||
result = {}
|
||||
|
||||
# Process each key
|
||||
for key in tensor_dicts[0].keys():
|
||||
tensors_to_concat = []
|
||||
|
||||
for tensor_dict in tensor_dicts:
|
||||
tensor = tensor_dict[key]
|
||||
# Skip 1D tensors like rewards
|
||||
|
@ -131,6 +145,9 @@ def concat_padded_tensors(
|
|||
tensors_to_concat.append(tensor)
|
||||
continue
|
||||
current_length = tensor.shape[1]
|
||||
if key == "pixel_values" or key == "image_grid_thw":
|
||||
tensors_to_concat.append(tensor)
|
||||
continue
|
||||
if current_length < max_length:
|
||||
# Pad tensor to max_length
|
||||
pad_width = max_length - current_length
|
||||
|
@ -139,11 +156,13 @@ def concat_padded_tensors(
|
|||
padding = torch.zeros(
|
||||
(tensor.shape[0], pad_width), dtype=tensor.dtype
|
||||
)
|
||||
|
||||
else:
|
||||
# Pad feature tensors with pad_value
|
||||
padding = torch.full(
|
||||
(tensor.shape[0], pad_width), pad_value, dtype=tensor.dtype
|
||||
)
|
||||
|
||||
tensor = torch.cat([tensor, padding], dim=1)
|
||||
tensors_to_concat.append(tensor)
|
||||
|
||||
|
@ -226,12 +245,14 @@ def pack_tensor_dict(data: TensorDict):
|
|||
total_length = int(cu_seqlens[-1].item())
|
||||
# Pack tensors
|
||||
packed_data = {}
|
||||
packed_data["cu_seqlens"] = cu_seqlens
|
||||
packed_data["max_seqlen"] = max_seqlen
|
||||
for key, value in data.items():
|
||||
if key == "attention_mask":
|
||||
packed_data["cu_seqlens"] = cu_seqlens
|
||||
packed_data["max_seqlen"] = max_seqlen
|
||||
# tensor and of shape [B, S, ...]
|
||||
elif (
|
||||
# if key == "attention_mask":
|
||||
# packed_data["cu_seqlens"] = cu_seqlens
|
||||
# packed_data["max_seqlen"] = max_seqlen
|
||||
# # tensor and of shape [B, S, ...]
|
||||
if (
|
||||
torch.is_tensor(value)
|
||||
and value.ndim >= 2
|
||||
and value.shape[0] == bs
|
||||
|
@ -310,6 +331,8 @@ def split_padded_tensor_dict_into_mb_list(
|
|||
to_split = {}
|
||||
not_to_split = {}
|
||||
for key, value in data.items():
|
||||
if key == "image_grid_thw" or key == "pixel_values":
|
||||
continue
|
||||
if not torch.is_tensor(value) or value.numel() != bs * max_seqlen:
|
||||
not_to_split[key] = value
|
||||
else:
|
||||
|
@ -343,6 +366,25 @@ def split_padded_tensor_dict_into_mb_list(
|
|||
return splitted
|
||||
|
||||
to_split = dict_map(to_split, lambda x: _split(x))
|
||||
if data.get("pixel_values", None) is not None:
|
||||
pixel_values = data.get("pixel_values", [])
|
||||
image_grid_thw = data.get("image_grid_thw", [])
|
||||
|
||||
# Prepare the pixel_values and image_grid_thw for each group
|
||||
pixel_values_split = []
|
||||
image_grid_thw_split = []
|
||||
|
||||
for group_index in group_indices:
|
||||
group_pixel_values = [pixel_values[i] for i in group_index]
|
||||
group_image_grid_thw = [image_grid_thw[i].squeeze() for i in group_index]
|
||||
|
||||
# Stack pixel_values for each group (assuming pixel_values is a list of tensors)
|
||||
pixel_values_split.append(torch.stack(group_pixel_values))
|
||||
image_grid_thw_split.append(torch.stack(group_image_grid_thw))
|
||||
|
||||
# Pack the split pixel_values and image_grid_thw back into the data
|
||||
to_split["pixel_values"] = pixel_values_split
|
||||
to_split["image_grid_thw"] = image_grid_thw_split
|
||||
mbs = dict_of_list2list_of_dict(to_split)
|
||||
|
||||
results = []
|
||||
|
@ -447,7 +489,11 @@ def unsqueeze_packed_tensor_dict(data: TensorDict) -> TensorDict:
|
|||
new_data = {}
|
||||
for key, value in data.items():
|
||||
if (
|
||||
key not in ["cu_seqlens", "max_seqlen"]
|
||||
key
|
||||
not in [
|
||||
"cu_seqlens",
|
||||
"max_seqlen",
|
||||
]
|
||||
and torch.is_tensor(value)
|
||||
and value.numel() == total_length
|
||||
):
|
||||
|
|
|
@@ -46,6 +46,7 @@ def fsdp2_clip_grad_norm_(
    grads = [p.grad for p in parameters if p.grad is not None]
    total_norm = _get_total_norm(grads, norm_type, error_if_nonfinite, foreach)
    total_norm = total_norm.to(torch.cuda.current_device(), non_blocking=True)

    _clip_grads_with_norm_(parameters, max_norm, total_norm, foreach)
    return total_norm
@@ -0,0 +1,45 @@
import base64
import math
from dataclasses import MISSING
from io import BytesIO
from typing import List

from PIL import Image
from PIL.Image import Image as ImageObject


def image2base64(images: List[ImageObject] | ImageObject) -> List[str] | str:

    if isinstance(images, ImageObject):
        images = [images]

    byte_images = []
    for image in images:
        with BytesIO() as buffer:
            image.save(buffer, format="PNG")
            buffer.seek(0)
            byte_image = base64.b64encode(buffer.read()).decode("utf-8")
            byte_images.append(byte_image)

    return byte_images


def pad_images_batch_to_max_size(images):
    max_width = max(image.size[0] for image in images)
    max_height = max(image.size[1] for image in images)

    padded_images = []

    for image in images:

        width, height = image.size

        padding_left = (max_width - width) // 2
        padding_top = (max_height - height) // 2

        padded_image = Image.new("RGB", (max_width, max_height), (0, 0, 0))
        padded_image.paste(image, (padding_left, padding_top))

        padded_images.append(padded_image)

    return padded_images
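A small usage sketch of the two helpers above, using synthetic images as placeholders: every image in a batch is centered on a black canvas of the batch's maximum width and height, then serialized to base64 PNG strings for the SGLang request payload.

```python
# Illustration with synthetic images; real code passes dataset images here.
from PIL import Image

from arealite.utils.image import image2base64, pad_images_batch_to_max_size

images = [Image.new("RGB", (320, 240)), Image.new("RGB", (200, 400))]
padded = pad_images_batch_to_max_size(images)  # both become 320x400, centered on black
byte_images = image2base64(padded)             # list of base64-encoded PNG strings
```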
@@ -1,5 +1,12 @@
import torch

VALID_VISION_MODELS = [
    "qwen2_vl",
    "qwen2_5_vl",
]
# This registry is used to check if a model is a vision model that we have checked it works with AReaLite. As different vision models vary in their image processing, special tokens and keys, etc. We will add models to this registry as we test them.
# If you want to add a new vision model, please make sure it works with AReaLite.


# Copied from trl
def disable_dropout_in_model(model: torch.nn.Module) -> None:
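Elsewhere in this diff, `BaseHFEngine` compares the HuggingFace `model_type` against this registry to decide whether to load the checkpoint via `AutoModelForImageTextToText`. A standalone sketch of that check; the checkpoint path is a placeholder:

```python
# Sketch of the registry check used by BaseHFEngine; the path is a placeholder.
from transformers import AutoConfig

from arealite.utils.model import VALID_VISION_MODELS

cfg = AutoConfig.from_pretrained("/path/to/Qwen2.5-VL-checkpoint", trust_remote_code=True)
is_vision_model = cfg.model_type in VALID_VISION_MODELS  # True for "qwen2_5_vl"
```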
@@ -1,7 +1,7 @@
import getpass
import os

from transformers import PreTrainedTokenizerFast
from transformers import AutoProcessor, PreTrainedTokenizerFast

from arealite.api.cli_args import SaverConfig
from arealite.api.engine_api import TrainEngine
@@ -56,6 +56,7 @@ class Saver:
        global_step: int,
        name: str = "default",
        tokenizer: PreTrainedTokenizerFast | None = None,
        processor: AutoProcessor | None = None,
        base_model_path: str | None = None,
    ):
        if not self.freq_ctl.check(

@@ -76,6 +77,7 @@ class Saver:
            weight_format=weight_format,
            with_optim=with_optim,
            tokenizer=tokenizer,
            processor=processor,
            base_model_path=base_model_path,
        )
        engine.save(meta)
@@ -38,6 +38,7 @@ class RLVRWorkflow(RolloutWorkflow):
            add_generation_prompt=True,
            enable_thinking=self.enable_thinking,
        )

        n_samples = self.gconfig.n_samples
        req = LLMRequest(
            rid=uuid.uuid4().hex,
@ -0,0 +1,121 @@
|
|||
import asyncio
|
||||
import os
|
||||
import uuid
|
||||
|
||||
import colorama
|
||||
import torch
|
||||
from tensordict import TensorDict
|
||||
from transformers import AutoProcessor, PreTrainedTokenizerFast
|
||||
|
||||
from arealite.api.cli_args import GenerationHyperparameters
|
||||
from arealite.api.io_struct import VLMRequest
|
||||
from arealite.utils.data import concat_padded_tensors
|
||||
from arealite.utils.image import image2base64, pad_images_batch_to_max_size
|
||||
from arealite.workflow.rlvr import RLVRWorkflow
|
||||
|
||||
|
||||
class VisionRLVRWorkflow(RLVRWorkflow):
|
||||
def __init__(
|
||||
self,
|
||||
reward_fn,
|
||||
gconfig: GenerationHyperparameters,
|
||||
tokenizer: PreTrainedTokenizerFast,
|
||||
processor: AutoProcessor,
|
||||
enable_thinking: bool,
|
||||
dump_dir: str | None = None,
|
||||
):
|
||||
super().__init__(reward_fn, gconfig, tokenizer, enable_thinking, dump_dir)
|
||||
self.processor = processor
|
||||
|
||||
async def arun_episode(self, engine, data):
|
||||
|
||||
padded_images = pad_images_batch_to_max_size(data["images"])
|
||||
|
||||
processed_input = self.processor(
|
||||
images=padded_images,
|
||||
text=data["messages"],
|
||||
padding=False,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
input_ids = processed_input["input_ids"].tolist()[0]
|
||||
|
||||
n_samples = self.gconfig.n_samples
|
||||
|
||||
byte_images = image2base64(padded_images)
|
||||
|
||||
req = VLMRequest(
|
||||
rid=uuid.uuid4().hex,
|
||||
input_ids=input_ids,
|
||||
image_data=byte_images,
|
||||
gconfig=self.gconfig.new(n_samples=1),
|
||||
)
|
||||
resps = await asyncio.gather(*[engine.agenerate(req) for _ in range(n_samples)])
|
||||
|
||||
version = engine.get_version()
|
||||
prompt_strs = []
|
||||
completions_strs = []
|
||||
rewards = []
|
||||
seqlens = []
|
||||
|
||||
results = []
|
||||
for resp in resps:
|
||||
seq = resp.input_tokens + resp.output_tokens
|
||||
logprobs = [0.0] * resp.input_len + resp.output_logprobs
|
||||
loss_mask = [0] * resp.input_len + [1] * resp.output_len
|
||||
versions = [-1] * resp.input_len + resp.output_versions
|
||||
|
||||
prompt_str = self.tokenizer.decode(input_ids)
|
||||
completions_str = self.tokenizer.decode(resp.output_tokens)
|
||||
prompt_strs.append(prompt_str)
|
||||
completions_strs.append(completions_str)
|
||||
seqlens.append(len(seq))
|
||||
reward = self.reward_fn(
|
||||
prompt=prompt_str,
|
||||
completions=completions_str,
|
||||
prompt_ids=resp.input_tokens,
|
||||
completion_ids=resp.output_tokens,
|
||||
**data,
|
||||
)
|
||||
rewards.append(reward)
|
||||
res = dict(
|
||||
# unsqueeze to add an additional batch dimension
|
||||
input_ids=torch.tensor(seq).unsqueeze(0),
|
||||
loss_mask=torch.tensor(loss_mask).unsqueeze(0),
|
||||
pixel_values=processed_input["pixel_values"].unsqueeze(0),
|
||||
image_grid_thw=processed_input["image_grid_thw"].unsqueeze(0),
|
||||
logprobs=torch.tensor(logprobs).unsqueeze(0),
|
||||
versions=torch.tensor(versions).unsqueeze(0),
|
||||
attention_mask=torch.ones(len(seq), dtype=torch.bool).unsqueeze(0),
|
||||
# reward
|
||||
rewards=torch.tensor([reward]),
|
||||
)
|
||||
results.append(TensorDict(res, batch_size=[1]))
|
||||
if self.dump_dir is not None:
|
||||
os.makedirs(os.path.join(self.dump_dir, str(version)), exist_ok=True)
|
||||
# Get the unique identifier for this prompt
|
||||
qid = None
|
||||
for key in ["query_id", "id", "qid"]:
|
||||
qid = data.get(key, None)
|
||||
if qid is not None:
|
||||
break
|
||||
qid = qid or uuid.uuid4().hex
|
||||
|
||||
# Dump rollout to file
|
||||
with open(
|
||||
os.path.join(self.dump_dir, str(version), f"{qid}.txt"), "a"
|
||||
) as f:
|
||||
n_samples = self.gconfig.n_samples
|
||||
for i, (p, c, r, sl) in enumerate(
|
||||
zip(prompt_strs, completions_strs, rewards, seqlens)
|
||||
):
|
||||
info = "\n".join(
|
||||
[
|
||||
f"idx: {i + 1} / {n_samples}, seqlen: {sl}, reward is {r}.",
|
||||
f"prompt is \n{colorama.Fore.YELLOW + colorama.Style.DIM}{p}{colorama.Style.RESET_ALL}",
|
||||
f"sequence is: \n{colorama.Fore.YELLOW + colorama.Style.DIM}{c}{colorama.Style.RESET_ALL}",
|
||||
]
|
||||
)
|
||||
f.write(info + "\n")
|
||||
|
||||
return concat_padded_tensors(results)
|
|
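The new file above (`arealite/workflow/vision_rlvr.py`, per the import used in the GRPO example script later in this diff) subclasses `RLVRWorkflow` and adds image handling. A hedged wiring sketch that mirrors that script; `config`, `tokenizer`, `processor`, `rollout`, `data`, and `clevr_count_70k_reward_fn` are assumed to come from the surrounding training code:

```python
# Sketch of wiring the vision workflow into training; names above are assumed
# to be defined by the GRPO example script shown later in this diff.
from arealite.workflow.vision_rlvr import VisionRLVRWorkflow

workflow = VisionRLVRWorkflow(
    reward_fn=clevr_count_70k_reward_fn,
    gconfig=config.gconfig,
    tokenizer=tokenizer,
    processor=processor,
    enable_thinking=False,
)
batch = rollout.rollout_batch(data, workflow=workflow)  # same call as the text-only path
```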
@ -1,17 +1,20 @@
|
|||
# Rollout and Agentic RL
|
||||
|
||||
This guide shows you how to create custom rollout behaviors for RL training by building
|
||||
a multi-turn math agent with **AReaLite**. This agent keeps trying to solve math
|
||||
problems until it finds the correct answer.
|
||||
This guide demonstrates how to customize rollout behavior for PPO training by
|
||||
implementing a multi-turn math agent that uses end-to-end reinforcement learning. Our
|
||||
example agent will continuously try to solve a math problem until it reaches the correct
|
||||
answer.
|
||||
|
||||
You can find the complete implementation in `arealite/workflow/multi_turn.py`.
|
||||
## Approach: Using AReaLite (Recommended)
|
||||
|
||||
## Step 1: Define Your Workflow
|
||||
The complete implementation is placed at `arealite/workflow/multi_turn.py`.
|
||||
|
||||
AReaLite gives you flexibility in how you design your agents. Instead of rigid `Agent`
|
||||
classes that might constrain your agent's capabilities, AReaLite captures all rollout
|
||||
behavior in a `RolloutWorkflow` class. This approach lets you customize your agent's
|
||||
behavior however you need.
|
||||
### Step 1: Define Your Workflow
|
||||
|
||||
AReaLite takes a flexible approach to agent definition. Rather than using rigid `Agent`
|
||||
classes that might limit your agentic capabilities, AReaLite captures all rollout
|
||||
behavior in a `RolloutWorkflow` class. This design gives you complete freedom to
|
||||
customize your agent's behavior.
|
||||
|
||||
```python
|
||||
# arealite/api/workflow_api.py
|
||||
|
@ -26,8 +29,8 @@ class RolloutWorkflow:
|
|||
raise NotImplementedError()
|
||||
```
|
||||
|
||||
The workflow exposes an `arun_episode` method that runs and collects data from a single
|
||||
episode. This method takes two key arguments:
|
||||
The workflow exposes a single `arun_episode` method that runs and collects data from a
|
||||
single episode. This method takes two key arguments:
|
||||
|
||||
1. **InferenceEngine**: Provides the `agenerate` method for generating responses to user
|
||||
inputs
|
||||
|
@ -36,19 +39,14 @@ episode. This method takes two key arguments:
|
|||
Within this method, you have complete control over how your agent and environment
|
||||
interact.
|
||||
|
||||
> **Note**: Each `arun_episode` call takes a single prompt and outputs the trajectories
|
||||
> generated from that prompt—it's not batched. However, you can generate multiple
|
||||
> trajectories from a single prompt (for example, with GRPO or tree search).
|
||||
|
||||
### Setting Up the Multi-Turn Math Workflow
|
||||
#### Setting Up the Multi-Turn Math Workflow
|
||||
|
||||
Let's build a multi-turn rollout workflow for solving math problems. First, we'll define
|
||||
the `__init__` method to set up what we need during rollout:
|
||||
the `__init__` method to capture the utilities we need during rollout:
|
||||
|
||||
> **Note**: You have complete flexibility in defining the `__init__` method. Pass
|
||||
> whatever arguments you need to construct your workflow. If you want to use tools, pass
|
||||
> the corresponding environment here so your agent can call it in the `arun_episode`
|
||||
> method.
|
||||
> **Note**: You have complete flexibility in defining the `__init__` method. Pass any
|
||||
> arguments needed to construct your workflow. If you want to use tools, pass the
|
||||
> corresponding environment here so your agent can call it in the `arun_episode` method.
|
||||
|
||||
```python
|
||||
class MultiTurnWorkflow(RolloutWorkflow):
|
||||
|
@ -68,7 +66,7 @@ class MultiTurnWorkflow(RolloutWorkflow):
|
|||
self.turn_discount = turn_discount
|
||||
```
|
||||
|
||||
### Implementing the Episode Logic
|
||||
#### Implementing the Episode Logic
|
||||
|
||||
Now let's implement the `arun_episode` method. We'll start by tokenizing the prompt data
|
||||
and converting it into an `LLMRequest` object for the inference engine:
|
||||
|
@ -102,20 +100,20 @@ class MultiTurnWorkflow(RolloutWorkflow):
|
|||
# ... continue processing ...
|
||||
```
|
||||
|
||||
> **Note**: This example uses the "messages" key from the prompt data to get
|
||||
> OpenAI-compatible messages. This isn't required—the key and prompt format depend
|
||||
> **Note**: This example accesses the "messages" key from the prompt data to get
|
||||
> OpenAI-compatible messages. This isn't mandatory—the key and prompt format depend
|
||||
> entirely on your implementation. For instance, if your dataset stores prompt strings
|
||||
> in a "prompt" column, you could get input token IDs with
|
||||
> in a "prompt" column, you could get input token IDs via
|
||||
> `self.tokenizer.encode(data["prompt"])`.
|
||||
|
||||
> **Note**: The `rid` field in `LLMRequest` is the request ID. Requests with the same ID
|
||||
> will reuse the LLM inference server's KV caches for better efficiency.
|
||||
> will reuse the LLM inference server's KV caches for efficiency.
|
||||
|
||||
### Handling Multi-Turn Conversations
|
||||
#### Handling Multi-Turn Conversations
|
||||
|
||||
Next, we'll check if the current answer is correct using our `reward_fn`. This function
|
||||
should return 1 for correct answers and 0 otherwise. When the answer is wrong, we'll
|
||||
apply a discount, add feedback to the conversation, and let the model try again:
|
||||
Next, we'll evaluate whether the current answer is correct using our `reward_fn`. This
|
||||
function should return 1 for correct answers and 0 otherwise. When the answer is wrong,
|
||||
we'll apply a discount, add feedback to the conversation, and let the model try again:
|
||||
|
||||
```python
|
||||
class MultiTurnWorkflow(RolloutWorkflow):
|
||||
|
@ -150,10 +148,10 @@ class MultiTurnWorkflow(RolloutWorkflow):
|
|||
discount *= self.turn_discount
|
||||
```
|
||||
|
||||
### Reward Function Signature
|
||||
#### Reward Function Signature
|
||||
|
||||
To make it easier to switch between different reward functions, we recommend following
|
||||
this signature:
|
||||
For convenience when switching between different reward functions, we recommend
|
||||
following this pre-defined signature:
|
||||
|
||||
```python
|
||||
def reward_fn(
|
||||
|
@ -180,10 +178,10 @@ def reward_fn(
|
|||
"""
|
||||
```
|
||||
|
||||
While this signature is convenient, you're not restricted to it in custom
|
||||
workflows—modify as needed for your specific use case.
|
||||
While this signature is convenient, there are no strict restrictions on reward functions
|
||||
in custom workflows—modify them as needed for your specific use case.
|
||||
|
||||
### Collecting Training Data
|
||||
#### Collecting Training Data
|
||||
|
||||
Finally, let's complete the implementation by collecting trajectories in the
|
||||
`TensorDict` format:
|
||||
|
@ -218,6 +216,191 @@ class MultiTurnWorkflow(RolloutWorkflow):
|
|||
return concat_padded_tensors([res])
|
||||
```
|
||||
|
||||
> **Important**: The returned `TensorDict` must follow HuggingFace's padded data format,
|
||||
> where each tensor has shape `[batch_size, sequence_length, *]`. This allows AReaLite
|
||||
> to automatically batch multiple trajectories for the training engine. Since this
|
||||
> example returns a single trajectory, we use `unsqueeze(0)` to create a size-1 batch.
|
||||
|
||||
> **Note**: There are no restrictions on the keys in your `TensorDict`—different
|
||||
> algorithms require different keys. This example targets the GRPO algorithm, so we
|
||||
> include `input_ids`, `loss_mask`, `attention_mask`, and `logprobs` (needed for
|
||||
> computing importance ratios).
|
||||
|
||||
### Step 2: Training with Your Custom Workflow
|
||||
|
||||
Using your custom workflow is straightforward—just construct it in your training script
|
||||
and pass it to the `rollout_batch` or `prepare_batch` method:
|
||||
|
||||
```python
|
||||
def main(args):
|
||||
# ... setup code ...
|
||||
|
||||
# Create your custom workflow
|
||||
workflow = MultiTurnWorkflow(
|
||||
reward_fn=gsm8k_reward_fn,
|
||||
gconfig=config.gconfig,
|
||||
tokenizer=tokenizer,
|
||||
turn_discount=0.9,
|
||||
max_turns=5,
|
||||
)
|
||||
|
||||
# Run training—no other changes needed!
|
||||
data_generator = iter(train_dataloader)
|
||||
for global_step in range(max_steps):
|
||||
with stats_tracker.record_timing("rollout"):
|
||||
if config.async_training:
|
||||
batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
|
||||
else:
|
||||
try:
|
||||
data = next(data_generator)
|
||||
except StopIteration:
|
||||
data_generator = iter(train_dataloader)
|
||||
data = next(data_generator)
|
||||
batch = rollout.rollout_batch(data, workflow=workflow)
|
||||
# ... continue with training loop ...
|
||||
```
|
||||
|
||||
That's it! Your custom multi-turn math agent is now ready to train with reinforcement
|
||||
learning. The workflow will automatically handle the multi-turn conversations, reward
|
||||
computation, and data collection needed for effective RL training.
|
||||
|
||||
## Alternative Approach: Using the Legacy Version (Not Recommended)
|
||||
|
||||
While we strongly recommend using AReaLite for new projects, you might encounter legacy
|
||||
code that uses the older Agent-based approach. Here's how it works for reference, though
|
||||
we suggest migrating to the workflow-based system when possible.
|
||||
|
||||
### Step 1: Define Your Agent Class
|
||||
|
||||
Create a new file under `realhf/impl/agent/`, such as `math_multi_turn_agent.py`. Your
|
||||
`Agent` must implement the interface defined in `realhf/api/core/agent.py`, which
|
||||
requires a single method: `collect_trajectory`.
|
||||
|
||||
```python
|
||||
class MathMultiTurnAgent(Agent):
|
||||
async def collect_trajectory(
|
||||
self,
|
||||
reward_fn,
|
||||
gconfig: GenerationHyperparameters, # aka sampling_params
|
||||
tokenizer: PreTrainedTokenizerFast,
|
||||
max_turns: int,
|
||||
turn_discount: float,
|
||||
):
|
||||
# Implementation goes here
|
||||
...
|
||||
```
|
||||
|
||||
### Step 2: Implement the Trajectory Collection Logic
|
||||
|
||||
The `collect_trajectory` method takes a task prompt, an environment, and two
|
||||
communication queues. Within this method, you control the data flow between your agent
|
||||
and the inference engine using these queues:
|
||||
|
||||
- **obs_queue**: Send observations (token IDs and generation config) to the inference
|
||||
engine
|
||||
- **act_queue**: Receive actions (generated responses) from the inference engine
|
||||
|
||||
Here's how the multi-turn conversation works:
|
||||
|
||||
```python
|
||||
for turn in range(self.num_turns):
|
||||
# Send the current state to the inference engine
|
||||
await obs_queue.put((qid, token_ids, self.gconfig))
|
||||
|
||||
# Get the generated response
|
||||
act: BundledGenerationOutputs = await act_queue.get()
|
||||
|
||||
# Evaluate the response through the environment
|
||||
success, rewards = await env.step((qid, answers))
|
||||
# ... process results ...
|
||||
```
|
||||
|
||||
#### Environment Integration
|
||||
|
||||
The environment follows a
|
||||
[Gym-like interface](https://github.com/Farama-Foundation/Gymnasium) with `reset` and
|
||||
`step` methods, but uses asynchronous implementations to prevent blocking across
|
||||
different environment instances.
|
||||
|
||||
For math problems, the environment is typically stateless and acts as a wrapper around
|
||||
your reward function:
|
||||
|
||||
```python
|
||||
class MathCodeSingleStepEnv(EnvironmentService):
|
||||
async def step(self, action: Tuple[str, List[str]]):
|
||||
qid, answers = action
|
||||
# ... setup code ...
|
||||
|
||||
# Run reward computation asynchronously
|
||||
format_rewards = await asyncio.to_thread(
|
||||
math_verify_call,
|
||||
answers,
|
||||
# ... other parameters ...
|
||||
)
|
||||
return None, format_rewards, True, False, {}
|
||||
```
|
||||
|
||||
#### Handling Multi-Turn Feedback
|
||||
|
||||
After receiving the reward from `env.step`, check if the answer is correct. If not,
|
||||
provide feedback and continue to the next turn:
|
||||
|
||||
```python
|
||||
for turn in range(self.num_turns):
|
||||
# ... generation and evaluation code ...
|
||||
|
||||
# Provide feedback based on the result
|
||||
if success[0]:
|
||||
feedback = "Congratulations! You are correct!"
|
||||
else:
|
||||
feedback = "Unfortunately your answer is wrong. Let's try again."
|
||||
|
||||
# Format feedback as a user message
|
||||
feedback = "\n" + self.tokenizer.apply_chat_template(
|
||||
[{"content": feedback, "role": "user"}],
|
||||
add_generation_prompt=True,
|
||||
tokenize=False,
|
||||
)
|
||||
|
||||
# Add feedback tokens to the conversation
|
||||
feedback_tokens = self.tokenizer(feedback)["input_ids"]
|
||||
token_ids.extend(feedback_tokens)
|
||||
```
|
||||
|
||||
### Step 3: Register and Configure Your Agent
|
||||
|
||||
First, register your agent implementation:
|
||||
|
||||
```python
|
||||
class MultiTurnWorkflow(RolloutWorkflow):
|
||||
# ... previous methods ...
|
||||
|
||||
async def arun_episode(self, engine: InferenceEngine, data):
|
||||
# ... episode logic above ...
|
||||
|
||||
while reward == 0 and t < self.max_turns:
|
||||
# ... generation and evaluation ...
|
||||
|
||||
# Collect trajectory data
|
||||
input_len = len(resp.input_tokens) - len(seq)
|
||||
seq += resp.input_tokens[-input_len:] + resp.output_tokens
|
||||
logprobs += [0.0] * input_len + resp.output_logprobs
|
||||
loss_mask += [0] * input_len + [1] * resp.output_len
|
||||
versions += [-1] * input_len + resp.output_versions
|
||||
|
||||
# Package results
|
||||
res = dict(
|
||||
input_ids=torch.tensor(seq),
|
||||
logprobs=torch.tensor(logprobs),
|
||||
loss_mask=torch.tensor(loss_mask),
|
||||
versions=torch.tensor(versions),
|
||||
rewards=torch.tensor(float(reward * discount)),
|
||||
attention_mask=torch.ones(len(seq), dtype=torch.bool),
|
||||
)
|
||||
res = {k: v.unsqueeze(0) for k, v in res.items()}
|
||||
return concat_padded_tensors([res])
|
||||
```
|
||||
|
||||
> **Important**: The returned `TensorDict` must follow HuggingFace's padded data format,
|
||||
> where each tensor has shape `[batch_size, sequence_length, *]`. This allows AReaLite
|
||||
> to automatically batch multiple trajectories for training. Since this example returns
|
||||
|
@ -234,34 +417,62 @@ Using your custom workflow is straightforward—just create it in your training
|
|||
pass it to the `rollout_batch` or `prepare_batch` method:
|
||||
|
||||
```python
|
||||
def main(args):
|
||||
# ... setup code ...
|
||||
|
||||
# Create your custom workflow
|
||||
workflow = MultiTurnWorkflow(
|
||||
reward_fn=gsm8k_reward_fn,
|
||||
gconfig=config.gconfig,
|
||||
tokenizer=tokenizer,
|
||||
turn_discount=0.9,
|
||||
max_turns=5,
|
||||
)
|
||||
|
||||
# Run training—no other changes needed!
|
||||
data_generator = iter(train_dataloader)
|
||||
for global_step in range(max_steps):
|
||||
with stats_tracker.record_timing("rollout"):
|
||||
if config.async_training:
|
||||
batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
|
||||
else:
|
||||
try:
|
||||
data = next(data_generator)
|
||||
except StopIteration:
|
||||
data_generator = iter(train_dataloader)
|
||||
data = next(data_generator)
|
||||
batch = rollout.rollout_batch(data, workflow=workflow)
|
||||
# ... continue with training loop ...
|
||||
# in realhf/impl/agent/__init__.py
|
||||
import realhf.impl.agent.math_multi_turn_agent
|
||||
```
|
||||
|
||||
That's it! Your custom multi-turn math agent is now ready for reinforcement learning
|
||||
training. The workflow will automatically handle the multi-turn conversations, reward
|
||||
computation, and data collection needed for effective RL training.
|
||||
Then update your experiment configuration in
|
||||
`realhf/experiments/async_exp/async_math_ppo.py`:
|
||||
|
||||
```python
|
||||
@dataclasses.dataclass
|
||||
class AsyncPPOMATHConfig(AsyncRLExperimentConfig, PPOMATHConfig):
|
||||
# Add any new CLI arguments your agent needs
|
||||
my_param: float = 1.0
|
||||
|
||||
@property
|
||||
def agent(self) -> AgentAbstraction:
|
||||
return AgentAbstraction(
|
||||
"math-multi-turn", # Your registered agent name
|
||||
args=dict(
|
||||
# Pass any arguments needed for your __init__ method
|
||||
my_param=self.my_param,
|
||||
# ... other configuration ...
|
||||
),
|
||||
)
|
||||
|
||||
@property
|
||||
def env(self) -> EnvServiceAbstraction:
|
||||
# Update to use your custom environment if needed
|
||||
return EnvServiceAbstraction(
|
||||
"math-code-single-step",
|
||||
args=dict(dataset_path=self.dataset.path)
|
||||
)
|
||||
```
|
||||
|
||||
### Step 4: Run Training
|
||||
|
||||
Follow the standard training procedure outlined in the
|
||||
[quickstart guide](../tutorial/quickstart.md). Launch your experiment with:
|
||||
|
||||
```bash
|
||||
python3 training/main_async_ppo.py my_param=5.0 # plus any additional CLI arguments
|
||||
```
|
||||
|
||||
### Training Results
|
||||
|
||||
Here's an example of the training reward curve from our multi-turn math agent:
|
||||
|
||||

|
||||
|
||||
The agent successfully learns to solve math problems with improved accuracy over time,
|
||||
demonstrating the effectiveness of the multi-turn approach.
|
||||
|
||||
______________________________________________________________________
|
||||
|
||||
**Note**: While this legacy approach works, we strongly recommend using the AReaLite
|
||||
workflow system for new projects. It provides better flexibility, cleaner abstractions,
|
||||
and easier maintenance. Consider migrating existing legacy agents to the workflow-based
|
||||
approach when possible.
|
||||
|
||||
Happy coding!
|
||||
|
|
|
@ -0,0 +1,261 @@
|
|||
import os
|
||||
import re
|
||||
import sys
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import wandb
|
||||
from torch.utils.data import Subset
|
||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
|
||||
from arealite.api.cli_args import GRPOConfig, load_expr_config
|
||||
from arealite.api.io_struct import AllocationMode, FinetuneSpec, WeightUpdateMeta
|
||||
from arealite.dataset.__init__ import get_custom_dataset
|
||||
from arealite.engine.ppo.actor import FSDPPPOActor
|
||||
from arealite.engine.sglang_remote import RemoteSGLangEngine
|
||||
from arealite.utils.device import log_gpu_stats
|
||||
from arealite.utils.evaluator import Evaluator
|
||||
from arealite.utils.saver import Saver
|
||||
from arealite.utils.stats_logger import StatsLogger
|
||||
from arealite.workflow.vision_rlvr import VisionRLVRWorkflow
|
||||
from realhf.api.core.data_api import load_hf_processor_and_tokenizer
|
||||
from realhf.base import stats_tracker
|
||||
|
||||
|
||||
def extract_answer(pred_str, data_name, use_last_number=True):
|
||||
match = re.findall(r"\[([0-9\.]+)\]", pred_str)
|
||||
if match:
|
||||
return match[-1]
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def clevr_count_70k_reward_fn(
|
||||
prompt, completions, prompt_ids, completion_ids, answer, **kwargs
|
||||
):
|
||||
sol = extract_answer(completions, data_name="") # str number
|
||||
ans = answer
|
||||
|
||||
if sol is None:
|
||||
return 0
|
||||
if ans is None:
|
||||
return 0
|
||||
|
||||
if sol.strip() == ans.strip():
|
||||
print(f"completions: {completions}, answer: {answer}")
|
||||
return 1
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def main(args):
|
||||
|
||||
wandb.init(project="clevr_70k")
|
||||
|
||||
config, _ = load_expr_config(args, GRPOConfig)
|
||||
config: GRPOConfig
|
||||
|
||||
rank = int(os.getenv("RANK"))
|
||||
world_size = int(os.getenv("WORLD_SIZE"))
|
||||
processor, tokenizer = load_hf_processor_and_tokenizer(config.tokenizer_path)
|
||||
train_dataset = get_custom_dataset(
|
||||
path=config.train_dataset.path,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
split="train",
|
||||
type=config.train_dataset.type,
|
||||
processor=processor,
|
||||
)
|
||||
|
||||
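# Optionally train on a random subset of the data; with a ratio of 1.0 this is
# simply the full dataset in shuffled order.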
train_size = len(train_dataset)
|
||||
subset_size = int(1.0 * train_size)
|
||||
|
||||
random_indices = torch.randperm(train_size).tolist()[:subset_size]
|
||||
|
||||
subset_train_dataset = Subset(train_dataset, random_indices)
|
||||
|
||||
valid_dataset = get_custom_dataset(
|
||||
path=config.valid_dataset.path,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
split="test",
|
||||
type=config.valid_dataset.type,
|
||||
processor=processor,
|
||||
)
|
||||
# Create dataset and dataloaders
|
||||
train_dataloader = StatefulDataLoader(
|
||||
subset_train_dataset,
|
||||
batch_size=config.train_dataset.batch_size // world_size,
|
||||
shuffle=config.train_dataset.shuffle,
|
||||
num_workers=config.train_dataset.num_workers,
|
||||
collate_fn=lambda x: x,
|
||||
drop_last=config.train_dataset.drop_last,
|
||||
)
|
||||
valid_dataloader = StatefulDataLoader(
|
||||
valid_dataset,
|
||||
batch_size=config.valid_dataset.batch_size // world_size,
|
||||
shuffle=config.valid_dataset.shuffle,
|
||||
num_workers=config.valid_dataset.num_workers,
|
||||
collate_fn=lambda x: x,
|
||||
drop_last=config.valid_dataset.drop_last,
|
||||
)
|
||||
ft_spec = FinetuneSpec(
|
||||
total_train_epochs=config.total_train_epochs,
|
||||
dataset_size=len(train_dataloader) * config.train_dataset.batch_size,
|
||||
train_batch_size=config.train_dataset.batch_size,
|
||||
)
|
||||
|
||||
# Initialize inference engine
|
||||
rollout = RemoteSGLangEngine(config.rollout)
|
||||
rollout.initialize(None, ft_spec)
|
||||
eval_rollout = RemoteSGLangEngine(config.rollout)
|
||||
eval_rollout.initialize(None, ft_spec)
|
||||
# NOTE: set a large version such that eval does not have any offpolicyness control
|
||||
eval_rollout.set_version(int(1e12))
|
||||
|
||||
# Initialize train engine
|
||||
actor = FSDPPPOActor(config=config.actor)
|
||||
actor.initialize(None, ft_spec)
|
||||
ref = None
|
||||
if config.actor.kl_ctl > 0 and config.ref is not None:
|
||||
ref = FSDPPPOActor(config=config.ref)
|
||||
ref.initialize(None, ft_spec)
|
||||
|
||||
# NOTE: This script uses the disk-based weight transfer path
# (`WeightUpdateMeta.from_disk`). The meta is still created on every rank and
# the rank-0 copy is broadcast so that all ranks agree on the same update metadata.
# (The NCCL path, `WeightUpdateMeta.from_fsdp_nccl`, would likewise need to run
# on all ranks because it calls `engine.get_param_specs()`.)
weight_update_meta = [WeightUpdateMeta.from_disk(config.saver)]
dist.broadcast_object_list(weight_update_meta, src=0)
weight_update_meta = weight_update_meta[0]
|
||||
|
||||
# Create rollout workflow
|
||||
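# Make sure generation stops at the pad and EOS tokens.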
if tokenizer.pad_token_id not in config.gconfig.stop_token_ids:
|
||||
config.gconfig.stop_token_ids.append(tokenizer.pad_token_id)
|
||||
if tokenizer.eos_token_id not in config.gconfig.stop_token_ids:
|
||||
config.gconfig.stop_token_ids.append(tokenizer.eos_token_id)
|
||||
|
||||
workflow = VisionRLVRWorkflow(
|
||||
reward_fn=clevr_count_70k_reward_fn,
|
||||
gconfig=config.gconfig,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
enable_thinking=False,
|
||||
)
|
||||
|
||||
# Run training.
|
||||
saver = Saver(config.saver, ft_spec, for_recover=False)
|
||||
logger = StatsLogger(config.stats_logger, ft_spec)
|
||||
evaluator = Evaluator(config.evaluator, ft_spec)
|
||||
|
||||
total_epochs = config.total_train_epochs
|
||||
steps_per_epoch = len(train_dataloader)
|
||||
max_steps = total_epochs * steps_per_epoch
|
||||
|
||||
logger.info(f"total_epochs={total_epochs} step_per_epoch={steps_per_epoch}")
|
||||
data_generator = iter(train_dataloader)
|
||||
for global_step in range(max_steps):
|
||||
epoch = global_step // steps_per_epoch
|
||||
step = global_step % steps_per_epoch
|
||||
|
||||
with stats_tracker.record_timing("rollout"):
|
||||
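# Async training pulls a ready batch from the rollout engine, which keeps
# generating in the background; sync training generates rollouts for exactly
# one dataloader batch per step.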
if config.async_training:
|
||||
batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
|
||||
else:
|
||||
try:
|
||||
data = next(data_generator)
|
||||
except StopIteration:
|
||||
data_generator = iter(train_dataloader)
|
||||
data = next(data_generator)
|
||||
batch = rollout.rollout_batch(data, workflow=workflow)
|
||||
|
||||
batch = batch.to(actor.device)
|
||||
# Create barrier to synchronize all rollout processes.
|
||||
dist.barrier(device_ids=[actor.device.index])
|
||||
torch.cuda.synchronize()
|
||||
|
||||
if config.actor.recompute_logprob or config.actor.use_decoupled_loss:
|
||||
with stats_tracker.record_timing("recompute_logp"):
|
||||
logp = actor.compute_logp(batch)
|
||||
batch["prox_logp"] = logp
|
||||
log_gpu_stats("recompute logp")
|
||||
|
||||
if ref is not None:
|
||||
with stats_tracker.record_timing("ref_logp"):
|
||||
batch["ref_logp"] = ref.compute_logp(batch)
|
||||
|
||||
log_gpu_stats("ref logp")
|
||||
|
||||
with stats_tracker.record_timing("compute_advantage"):
|
||||
actor.compute_advantages(batch)
|
||||
log_gpu_stats("compute advantages")
|
||||
|
||||
with (
|
||||
stats_tracker.record_timing("train_step"),
|
||||
stats_tracker.scope("grpo_actor"),
|
||||
):
|
||||
stats = actor.ppo_update(batch)
|
||||
wandb.log({"final_reward": stats[0]["grpo_actor/final_reward/avg"]})
|
||||
wandb.log({"task_reward": stats[0]["grpo_actor/task_reward/avg"]})
|
||||
actor.step_lr_scheduler()
|
||||
log_gpu_stats("ppo update")
|
||||
|
||||
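# Propagate the updated weights to the inference engine: pause rollout,
# upload the trainer weights (rank 0 drives the remote update), then resume
# and advance the policy version on both trainer and rollout.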
with stats_tracker.record_timing("update_weights"):
|
||||
rollout.pause()
|
||||
if dist.get_rank() == 0:
|
||||
future = rollout.update_weights(weight_update_meta)
|
||||
actor.upload_weights(weight_update_meta)
|
||||
if dist.get_rank() == 0:
|
||||
future.result()
|
||||
dist.barrier(device_ids=[actor.device.index])
|
||||
torch.cuda.synchronize()
|
||||
rollout.resume()
|
||||
actor.set_version(global_step + 1)
|
||||
rollout.set_version(global_step + 1)
|
||||
|
||||
with stats_tracker.record_timing("save"):
|
||||
saver.save(actor, epoch, step, global_step)
|
||||
|
||||
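# Periodic evaluation: pause the training rollout engine, push the whole
# validation set through the eval engine, and log the mean reward.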
with stats_tracker.record_timing("eval"):
|
||||
|
||||
def evaluate_fn():
|
||||
rollout.pause()
|
||||
cnt = 0
|
||||
for data in valid_dataloader:
|
||||
for item in data:
|
||||
eval_rollout.submit(item, workflow)
|
||||
cnt += 1
|
||||
batch = eval_rollout.wait(cnt, timeout=None)
|
||||
rewards = batch["rewards"].float().to(actor.device)
|
||||
wandb.log({"eval_reward": rewards.mean().item()})
|
||||
with stats_tracker.scope("grpo-eval"):
|
||||
stats_tracker.denominator(
|
||||
n_seqs=torch.ones(
|
||||
rewards.shape[0],
|
||||
device=rewards.device,
|
||||
dtype=torch.bool,
|
||||
)
|
||||
)
|
||||
stats_tracker.stat(task_reward=rewards, denominator="n_seqs")
|
||||
rollout.resume()
|
||||
|
||||
evaluator.evaluate(
|
||||
evaluate_fn,
|
||||
epoch,
|
||||
step,
|
||||
global_step,
|
||||
)
|
||||
|
||||
logger.commit(epoch, step, global_step, stats)
|
||||
|
||||
logger.close()
|
||||
eval_rollout.destroy()
|
||||
rollout.destroy()
|
||||
if ref is not None:
|
||||
ref.destroy()
|
||||
actor.destroy()
|
||||
wandb.finish()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(sys.argv[1:])
|
|
@ -0,0 +1,124 @@
|
|||
import os
|
||||
import sys
|
||||
|
||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
|
||||
from arealite.api.cli_args import SFTConfig, load_expr_config
|
||||
from arealite.api.io_struct import FinetuneSpec
|
||||
from arealite.dataset.__init__ import get_custom_dataset
|
||||
from arealite.engine.sft.lm_engine import FSDPLMEngine
|
||||
from arealite.utils.data import pad_sequences_to_tensors
|
||||
from arealite.utils.evaluator import Evaluator
|
||||
from arealite.utils.saver import Saver
|
||||
from arealite.utils.stats_logger import StatsLogger
|
||||
from realhf.api.core.data_api import load_hf_processor_and_tokenizer
|
||||
from realhf.base import stats_tracker
|
||||
|
||||
|
||||
def main_sft():
|
||||
config, _ = load_expr_config(sys.argv[1:], SFTConfig)
|
||||
config: SFTConfig
|
||||
|
||||
rank = int(os.getenv("RANK"))
|
||||
world_size = int(os.getenv("WORLD_SIZE"))
|
||||
processor, tokenizer = load_hf_processor_and_tokenizer(config.tokenizer_path)
|
||||
train_dataset = get_custom_dataset(
|
||||
path=config.train_dataset.path,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
split="train",
|
||||
type=config.train_dataset.type,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
)
|
||||
valid_dataset = get_custom_dataset(
|
||||
path=config.valid_dataset.path,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
split="test",
|
||||
type=config.valid_dataset.type,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
)
|
||||
|
||||
# Create dataset and dataloaders
|
||||
train_dataloader = StatefulDataLoader(
|
||||
train_dataset,
|
||||
batch_size=config.train_dataset.batch_size // world_size,
|
||||
shuffle=config.train_dataset.shuffle,
|
||||
num_workers=config.train_dataset.num_workers,
|
||||
collate_fn=pad_sequences_to_tensors,
|
||||
drop_last=config.train_dataset.drop_last,
|
||||
)
|
||||
|
||||
valid_dataloader = StatefulDataLoader(
|
||||
valid_dataset,
|
||||
batch_size=config.valid_dataset.batch_size // world_size,
|
||||
shuffle=config.valid_dataset.shuffle,
|
||||
num_workers=config.valid_dataset.num_workers,
|
||||
collate_fn=pad_sequences_to_tensors,
|
||||
drop_last=config.valid_dataset.drop_last,
|
||||
)
|
||||
|
||||
# Initialize engine
|
||||
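# len(train_dataloader) counts per-rank steps, so the global dataset size is
# reconstructed as steps-per-epoch times the global batch size.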
ft_spec = FinetuneSpec(
|
||||
total_train_epochs=config.total_train_epochs,
|
||||
dataset_size=len(train_dataloader) * config.train_dataset.batch_size,
|
||||
train_batch_size=config.train_dataset.batch_size,
|
||||
)
|
||||
engine = FSDPLMEngine(config=config.model)
|
||||
engine.initialize(None, ft_spec)
|
||||
|
||||
# Run training.
|
||||
saver = Saver(config.saver, ft_spec, for_recover=False)
|
||||
logger = StatsLogger(config.stats_logger, ft_spec)
|
||||
evaluator = Evaluator(config.evaluator, ft_spec)
|
||||
|
||||
total_epochs = config.total_train_epochs
|
||||
steps_per_epoch = len(train_dataloader)
|
||||
|
||||
logger.info(f"total_epochs={total_epochs} step_per_epoch={steps_per_epoch}")
|
||||
global_step = 0
|
||||
for epoch in range(total_epochs):
|
||||
for step, data in enumerate(train_dataloader):
|
||||
|
||||
with (
|
||||
stats_tracker.record_timing("train_step"),
|
||||
stats_tracker.scope("sft"),
|
||||
):
|
||||
stats = engine.train_lm(data)
|
||||
engine.step_lr_scheduler()
|
||||
stats_tracker.scalar(**stats)
|
||||
|
||||
with stats_tracker.record_timing("save"):
|
||||
saver.save(engine, epoch, step, global_step)
|
||||
|
||||
with stats_tracker.record_timing("eval"):
|
||||
# No need to log anything. Logging will be handled outside
|
||||
# via stats_tracker.export().
|
||||
def evaluate_fn():
|
||||
with stats_tracker.scope("sft-eval"):
|
||||
for data in valid_dataloader:
|
||||
engine.evaluate_lm(data)
|
||||
|
||||
evaluator.evaluate(
|
||||
evaluate_fn,
|
||||
epoch,
|
||||
step,
|
||||
global_step,
|
||||
)
|
||||
|
||||
logger.commit(
|
||||
epoch,
|
||||
step,
|
||||
global_step,
|
||||
stats_tracker.export(reduce_group=engine.parallelism_group),
|
||||
)
|
||||
global_step += 1
|
||||
|
||||
logger.close()
|
||||
engine.destroy()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main_sft()
|
|
@ -0,0 +1,139 @@
|
|||
experiment_name: clevr_count_70k-grpo
|
||||
trial_name: trial1
|
||||
|
||||
|
||||
seed: 1
|
||||
total_train_epochs: 3
|
||||
tokenizer_path: ${actor.path}
|
||||
async_training: true
|
||||
|
||||
cluster:
|
||||
n_nodes: 1
|
||||
n_gpus_per_node: 8
|
||||
cluster_name: na132
|
||||
fileroot: /storage/openpsi/experiments
|
||||
name_resolve:
|
||||
type: nfs
|
||||
nfs_record_root: /storage/openpsi/experiments/name_resolve/clevr_count_70k-grpo
|
||||
etcd3_addr: etcd-client.openpsi-etcd.svc.sigma-na130-lingbo.na130.wl-robby.local:2379
|
||||
|
||||
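# GPU split on the single 8-GPU node: 1 GPU serves SGLang generation (d1p1t1),
# the remaining 7 GPUs run FSDP training (d7p1t1).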
allocation_mode: sglang.d1p1t1+d7p1t1
|
||||
|
||||
rollout:
|
||||
experiment_name: ${experiment_name}
|
||||
trial_name: ${trial_name}
|
||||
max_concurrent_rollouts: 256
|
||||
queue_size: null
|
||||
consumer_batch_size: ${train_dataset.batch_size}
|
||||
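# Cap on how stale rollout data may be, measured in policy versions.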
max_head_offpolicyness: 4
|
||||
enable_rollout_tracing: false
|
||||
|
||||
gconfig:
|
||||
n_samples: 4
|
||||
min_new_tokens: 0
|
||||
max_new_tokens: 512
|
||||
greedy: false
|
||||
temperature: 1.0
|
||||
|
||||
actor:
|
||||
experiment_name: ${experiment_name}
|
||||
trial_name: ${trial_name}
|
||||
path: /storage/openpsi/models/Qwen2.5-VL-3B-Instruct
|
||||
init_from_scratch: false
|
||||
disable_dropout: true
|
||||
gradient_checkpointing: false
|
||||
dtype: bfloat16
|
||||
mb_spec:
|
||||
max_tokens_per_mb: 10240
|
||||
optimizer:
|
||||
type: adam
|
||||
lr: 2e-6
|
||||
weight_decay: 0.01
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
eps: 1e-8
|
||||
lr_scheduler_type: constant
|
||||
gradient_clipping: 1.0
|
||||
warmup_steps_proportion: 0.001
|
||||
backend: fsdp
|
||||
|
||||
group_size: ${gconfig.n_samples}
|
||||
group_adv_norm: false
|
||||
eps_clip: 0.4
|
||||
temperature: ${gconfig.temperature}
|
||||
reward_scaling: 10.0
|
||||
reward_bias: -0.5
|
||||
kl_ctl: 0.0
|
||||
ppo_n_minibatches: 1
|
||||
recompute_logprob: true
|
||||
use_decoupled_loss: true
|
||||
behav_imp_weight_cap: 5.0
|
||||
|
||||
ref:
|
||||
experiment_name: ${experiment_name}
|
||||
trial_name: ${trial_name}
|
||||
path: ${actor.path}
|
||||
init_from_scratch: false
|
||||
dtype: ${actor.dtype}
|
||||
mb_spec:
|
||||
max_tokens_per_mb: 10240
|
||||
optimizer: null
|
||||
backend: fsdp
|
||||
|
||||
# SGLang
|
||||
server_only: false
|
||||
sglang:
|
||||
model_path: ${actor.path}
|
||||
random_seed: ${seed}
|
||||
skip_tokenizer_init: true
|
||||
dtype: ${actor.dtype}
|
||||
max_running_requests: null
|
||||
context_length: 32768
|
||||
mem_fraction_static: 0.8
|
||||
|
||||
# datasets
|
||||
train_dataset:
|
||||
batch_size: 32
|
||||
shuffle: true
|
||||
pin_memory: true
|
||||
num_workers: 4
|
||||
path: /storage/openpsi/data/clevr_count_70k/
|
||||
type: rl
|
||||
|
||||
valid_dataset:
|
||||
batch_size: 32
|
||||
shuffle: true
|
||||
pin_memory: true
|
||||
num_workers: 4
|
||||
path: /storage/openpsi/data/clevr_count_70k/
|
||||
type: rl
|
||||
|
||||
# Utilities
|
||||
saver:
|
||||
experiment_name: ${experiment_name}
|
||||
trial_name: ${trial_name}
|
||||
fileroot: ${cluster.fileroot}
|
||||
freq_epochs: 1
|
||||
freq_steps: null
|
||||
freq_secs: null
|
||||
|
||||
checkpointer:
|
||||
experiment_name: ${experiment_name}
|
||||
trial_name: ${trial_name}
|
||||
fileroot: ${cluster.fileroot}
|
||||
freq_epochs: 1
|
||||
freq_steps: null
|
||||
freq_secs: 3600
|
||||
|
||||
evaluator:
|
||||
experiment_name: ${experiment_name}
|
||||
trial_name: ${trial_name}
|
||||
fileroot: ${cluster.fileroot}
|
||||
freq_epochs: 1
|
||||
freq_steps: null
|
||||
freq_secs: null
|
||||
|
||||
stats_logger:
|
||||
experiment_name: ${experiment_name}
|
||||
trial_name: ${trial_name}
|
||||
fileroot: ${cluster.fileroot}
|
|
@ -0,0 +1,81 @@
|
|||
experiment_name: clevr_count_70k-sft
|
||||
trial_name: trial0
|
||||
|
||||
cluster:
|
||||
n_nodes: 1
|
||||
n_gpus_per_node: 1
|
||||
cluster_name: na132
|
||||
fileroot: /storage/openpsi/experiments
|
||||
name_resolve:
|
||||
type: nfs
|
||||
nfs_record_root: /storage/openpsi/experiments/name_resolve/clevr_count_70k-sft
|
||||
etcd3_addr: etcd-client.openpsi-etcd.svc.sigma-na130-lingbo.na130.wl-robby.local:2379
|
||||
seed: 1
|
||||
total_train_epochs: 3
|
||||
tokenizer_path: ${model.path}
|
||||
|
||||
model:
|
||||
experiment_name: ${experiment_name}
|
||||
trial_name: ${trial_name}
|
||||
path: /storage/openpsi/models/Qwen2-VL-7B
|
||||
dtype: bfloat16
|
||||
init_from_scratch: false
|
||||
gradient_checkpointing: false
|
||||
mb_spec:
|
||||
max_tokens_per_mb: 4096
|
||||
optimizer:
|
||||
type: adam
|
||||
lr: 2e-5
|
||||
weight_decay: 0.05
|
||||
beta1: 0.9
|
||||
beta2: 0.95
|
||||
eps: 1e-5
|
||||
lr_scheduler_type: cosine
|
||||
gradient_clipping: 1.0
|
||||
backend: fsdp
|
||||
|
||||
train_dataset:
|
||||
batch_size: 128
|
||||
shuffle: true
|
||||
pin_memory: true
|
||||
num_workers: 4
|
||||
path: /storage/openpsi/data/clevr_count_70k/
|
||||
type: sft
|
||||
|
||||
valid_dataset:
|
||||
batch_size: 128
|
||||
shuffle: true
|
||||
pin_memory: true
|
||||
num_workers: 4
|
||||
path: /storage/openpsi/data/clevr_count_70k/
|
||||
type: sft
|
||||
|
||||
# Utilities
|
||||
saver:
|
||||
experiment_name: ${experiment_name}
|
||||
trial_name: ${trial_name}
|
||||
fileroot: ${cluster.fileroot}
|
||||
freq_epochs: 1
|
||||
freq_steps: null
|
||||
freq_secs: null
|
||||
|
||||
checkpointer:
|
||||
experiment_name: ${experiment_name}
|
||||
trial_name: ${trial_name}
|
||||
fileroot: ${cluster.fileroot}
|
||||
freq_epochs: 1
|
||||
freq_steps: null
|
||||
freq_secs: 3600
|
||||
|
||||
evaluator:
|
||||
experiment_name: ${experiment_name}
|
||||
trial_name: ${trial_name}
|
||||
fileroot: ${cluster.fileroot}
|
||||
freq_epochs: null
|
||||
freq_steps: 1
|
||||
freq_secs: null
|
||||
|
||||
stats_logger:
|
||||
experiment_name: ${experiment_name}
|
||||
trial_name: ${trial_name}
|
||||
fileroot: ${cluster.fileroot}
|
|
@ -2,9 +2,9 @@ experiment_name: gsm8k-grpo
|
|||
trial_name: trial0
|
||||
allocation_mode: sglang.d4p1t1+d4p1t1
|
||||
cluster:
|
||||
fileroot: /tmp/arealite/experiments
|
||||
n_nodes: 1
|
||||
n_gpus_per_node: 8
|
||||
fileroot: /tmp/arealite/experiments
|
||||
name_resolve:
|
||||
type: nfs
|
||||
nfs_record_root: /tmp/areal/name_resolve
|
||||
|
@ -90,11 +90,17 @@ train_dataset:
|
|||
batch_size: 256
|
||||
shuffle: true
|
||||
pin_memory: true
|
||||
num_workers: 4
|
||||
path: openai/gsm8k
|
||||
type: rl
|
||||
|
||||
valid_dataset:
|
||||
batch_size: 256
|
||||
shuffle: true
|
||||
pin_memory: true
|
||||
num_workers: 4
|
||||
path: openai/gsm8k
|
||||
type: rl
|
||||
|
||||
# Utilities
|
||||
saver:
|
||||
|
@ -125,5 +131,4 @@ stats_logger:
|
|||
experiment_name: ${experiment_name}
|
||||
trial_name: ${trial_name}
|
||||
fileroot: ${cluster.fileroot}
|
||||
wandb:
|
||||
mode: disabled
|
||||
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
experiment_name: gsm8k-sft
|
||||
trial_name: trial0
|
||||
n_nodes: 1
|
||||
n_gpus_per_node: 8
|
||||
|
||||
cluster:
|
||||
n_nodes: 1
|
||||
n_gpus_per_node: 8
|
||||
name_resolve:
|
||||
type: nfs
|
||||
nfs_record_root: /tmp/areal/name_resolve
|
||||
|
@ -34,11 +35,17 @@ train_dataset:
|
|||
batch_size: 128
|
||||
shuffle: true
|
||||
pin_memory: true
|
||||
num_workers: 4
|
||||
path: openai/gsm8k
|
||||
type: sft
|
||||
|
||||
valid_dataset:
|
||||
batch_size: 128
|
||||
shuffle: true
|
||||
pin_memory: true
|
||||
num_workers: 4
|
||||
path: openai/gsm8k
|
||||
type: sft
|
||||
|
||||
# Utilities
|
||||
saver:
|
||||
|
@ -69,5 +76,3 @@ stats_logger:
|
|||
experiment_name: ${experiment_name}
|
||||
trial_name: ${trial_name}
|
||||
fileroot: ${cluster.fileroot}
|
||||
wandb:
|
||||
mode: disabled
|
|
@ -0,0 +1,125 @@
|
|||
import math
|
||||
from io import BytesIO
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
from datasets import load_dataset
|
||||
from datasets.distributed import split_dataset_by_node
|
||||
from PIL.Image import Image as ImageObject
|
||||
|
||||
|
||||
def convert_image(
    image: Union[Dict[str, Any], ImageObject, str],
    max_pixels: Optional[int],
) -> bytes:
    # Downscale the image so it contains at most `max_pixels` pixels, force RGB,
    # and return the JPEG-encoded bytes.
    if max_pixels is not None and (image.width * image.height) > max_pixels:
        resize_factor = math.sqrt(max_pixels / (image.width * image.height))
        width, height = int(image.width * resize_factor), int(
            image.height * resize_factor
        )
        image = image.resize((width, height))

    if image.mode != "RGB":
        image = image.convert("RGB")
    with BytesIO() as output:
        image.save(output, format="JPEG")
        return output.getvalue()
|
||||
|
||||
|
||||
def get_clevr_count_70k_sft_dataset(path, split, processor, rank, world_size):
|
||||
"""
|
||||
"clevr_count_70k": {
|
||||
"image_key": "images",
|
||||
"question_key": "problem",
|
||||
"answer_key": "answer"
|
||||
},
|
||||
"""
|
||||
dataset = load_dataset(path=path, split=split)
|
||||
dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
|
||||
|
||||
tokenizer = processor.tokenizer
|
||||
|
||||
def process_example(example, idx):
|
||||
# Substitute the processor-specific image token and build the training sequence.
|
||||
images = example["images"]
|
||||
if "qwen" in processor.image_processor.image_processor_type.lower():
|
||||
image_token = "<|vision_start|><|image_pad|><|vision_end|>"
|
||||
else:
|
||||
image_token = processor.image_token if processor is not None else "<image>"
|
||||
example["problem"] = (
|
||||
example["problem"].replace("<image>", image_token).replace("different", "")
|
||||
)
|
||||
processed_images = []
|
||||
for image in images:
|
||||
processed_images.append(convert_image(image, 336 * 336))
|
||||
example["images"] = processed_images
|
||||
example["seq"] = example["problem"] + example["answer"] + tokenizer.eos_token
|
||||
|
||||
return example
|
||||
|
||||
dataset = dataset.map(
|
||||
lambda example, idx: process_example(example, idx),
|
||||
with_indices=True,
|
||||
)
|
||||
|
||||
def _process(example):
|
||||
text = example["seq"]
|
||||
processed_input = processor(
|
||||
text=[text],
|
||||
images=example["images"],
|
||||
padding=False,
|
||||
return_tensors="pt",
|
||||
return_length=True,
|
||||
return_attention_mask=False,
|
||||
)
|
||||
|
||||
example["input_ids"] = processed_input["input_ids"].squeeze(0)
|
||||
example["pixel_values"] = processed_input["pixel_values"]
|
||||
example["image_grid_thw"] = processed_input["image_grid_thw"]
|
||||
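# Supervise only the answer tokens: every token before them (prompt and image
# tokens) is masked out of the loss.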
answer_token = tokenizer.encode(example["answer"])
|
||||
loss_mask = [0] * (len(example["input_ids"]) - len(answer_token)) + [1] * len(
|
||||
answer_token
|
||||
)
|
||||
example["loss_mask"] = loss_mask
|
||||
return example
|
||||
|
||||
dataset = dataset.map(
|
||||
lambda x: _process(x), remove_columns=["images", "seq", "problem", "answer"]
|
||||
)
|
||||
return dataset
|
||||
|
||||
|
||||
def get_clevr_count_70k_rl_dataset(path, split, processor, rank, world_size):
|
||||
dataset = load_dataset(path=path, split=split)
|
||||
dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
|
||||
|
||||
def process(sample):
|
||||
processed_images = [
|
||||
convert_image(image, 336 * 336) for image in sample["images"]
|
||||
]
|
||||
if "qwen" in processor.image_processor.image_processor_type.lower():
|
||||
image_token = "<|vision_start|><|image_pad|><|vision_end|>"
|
||||
else:
|
||||
image_token = processor.image_token if processor is not None else "<image>"
|
||||
system_prompt = {
|
||||
"role": "system",
|
||||
"content": (
|
||||
"Solve the following question: count the number of items in the image and provide the final answer in [ ] format, ensuring that only the number is inside the brackets without any additional text or explanations. "
|
||||
),
|
||||
}
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": sample["problem"]
|
||||
.replace("<image>", image_token)
|
||||
.replace("different", ""),
|
||||
}
|
||||
]
|
||||
messages.insert(0, system_prompt)
|
||||
messages = processor.tokenizer.apply_chat_template(
|
||||
messages, add_generation_prompt=True, tokenize=False
|
||||
)
|
||||
return {"messages": messages, "images": processed_images}
|
||||
|
||||
dataset = dataset.map(process).remove_columns(["problem"])
|
||||
return dataset
|
|
@ -0,0 +1,30 @@
|
|||
from datasets import Dataset, load_dataset
|
||||
from datasets.distributed import split_dataset_by_node
|
||||
|
||||
|
||||
def get_gsm8k_sft_dataset(path, split, tokenizer, rank, world_size):
|
||||
dataset = load_dataset(path=path, name="main", split=split)
|
||||
dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
|
||||
|
||||
def process(sample):
|
||||
seq_token = tokenizer.encode(
|
||||
sample["question"] + sample["answer"] + tokenizer.eos_token
|
||||
)
|
||||
prompt_token = tokenizer.encode(sample["question"])
|
||||
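# Mask out the question tokens; only the answer tokens contribute to the SFT loss.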
loss_mask = [0] * len(prompt_token) + [1] * (len(seq_token) - len(prompt_token))
|
||||
return {"input_ids": seq_token, "loss_mask": loss_mask}
|
||||
|
||||
dataset = dataset.map(process).remove_columns(["question", "answer"])
|
||||
return dataset
|
||||
|
||||
|
||||
def get_gsm8k_rl_dataset(path, split, rank, world_size):
|
||||
dataset = load_dataset(path=path, name="main", split=split)
|
||||
dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
|
||||
|
||||
def process(sample):
|
||||
messages = [{"role": "user", "content": sample["question"]}]
|
||||
return {"messages": messages}
|
||||
|
||||
dataset = dataset.map(process).remove_columns(["question"])
|
||||
return dataset
|
|
@ -3,12 +3,11 @@ import sys
|
|||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from datasets import Dataset, load_dataset
|
||||
from datasets.distributed import split_dataset_by_node
|
||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
|
||||
from arealite.api.cli_args import GRPOConfig, load_expr_config
|
||||
from arealite.api.io_struct import AllocationMode, FinetuneSpec, WeightUpdateMeta
|
||||
from arealite.dataset.__init__ import get_custom_dataset
|
||||
from arealite.engine.ppo.actor import FSDPPPOActor
|
||||
from arealite.engine.sglang_remote import RemoteSGLangEngine
|
||||
from arealite.utils.device import log_gpu_stats
|
||||
|
@ -22,21 +21,6 @@ from realhf.base import logging, seeding, stats_tracker
|
|||
logger = logging.getLogger("GSM8K grpo")
|
||||
|
||||
|
||||
def process_gsm8k_rl_dataset(dataset: Dataset):
|
||||
def process(sample):
|
||||
messages = [{"role": "user", "content": sample["question"]}]
|
||||
return {"messages": messages}
|
||||
|
||||
dataset = dataset.map(process).remove_columns(["question"])
|
||||
return dataset
|
||||
|
||||
|
||||
def get_gsm8k_dataset(split, rank, world_size):
|
||||
dataset = load_dataset(path="openai/gsm8k", name="main", split=split)
|
||||
dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
|
||||
return process_gsm8k_rl_dataset(dataset)
|
||||
|
||||
|
||||
def gsm8k_reward_fn(prompt, completions, prompt_ids, completion_ids, answer, **kwargs):
|
||||
from realhf.impl.dataset.math_parser import process_results
|
||||
|
||||
|
@ -52,10 +36,26 @@ def main(args):
|
|||
tokenizer = load_hf_tokenizer(config.tokenizer_path)
|
||||
|
||||
seeding.set_random_seed(config.seed, key=f"trainer{rank}")
|
||||
train_dataset = get_custom_dataset(
|
||||
path=config.train_dataset.path,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
split="train",
|
||||
type=config.train_dataset.type,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
valid_dataset = get_custom_dataset(
|
||||
path=config.valid_dataset.path,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
split="test",
|
||||
type=config.valid_dataset.type,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
# Create dataset and dataloaders
|
||||
train_dataloader = StatefulDataLoader(
|
||||
get_gsm8k_dataset("train", rank, world_size),
|
||||
train_dataset,
|
||||
batch_size=config.train_dataset.batch_size // world_size,
|
||||
shuffle=config.train_dataset.shuffle,
|
||||
num_workers=config.train_dataset.num_workers,
|
||||
|
@ -63,7 +63,7 @@ def main(args):
|
|||
drop_last=config.train_dataset.drop_last,
|
||||
)
|
||||
valid_dataloader = StatefulDataLoader(
|
||||
get_gsm8k_dataset("test", rank, world_size),
|
||||
valid_dataset,
|
||||
batch_size=config.valid_dataset.batch_size // world_size,
|
||||
shuffle=config.valid_dataset.shuffle,
|
||||
num_workers=config.valid_dataset.num_workers,
|
||||
|
|
|
@ -7,6 +7,7 @@ from torchdata.stateful_dataloader import StatefulDataLoader
|
|||
|
||||
from arealite.api.cli_args import SFTConfig, load_expr_config
|
||||
from arealite.api.io_struct import FinetuneSpec
|
||||
from arealite.dataset.__init__ import get_custom_dataset
|
||||
from arealite.engine.sft.lm_engine import FSDPLMEngine
|
||||
from arealite.utils.data import pad_sequences_to_tensors
|
||||
from arealite.utils.evaluator import Evaluator
|
||||
|
@ -16,25 +17,6 @@ from realhf.api.core.data_api import load_hf_tokenizer
|
|||
from realhf.base import stats_tracker
|
||||
|
||||
|
||||
def process_gsm8k_sft_dataset(dataset: Dataset, tokenizer):
|
||||
def process(sample):
|
||||
seq_token = tokenizer.encode(
|
||||
sample["question"] + sample["answer"] + tokenizer.eos_token
|
||||
)
|
||||
prompt_token = tokenizer.encode(sample["question"])
|
||||
loss_mask = [0] * len(prompt_token) + [1] * (len(seq_token) - len(prompt_token))
|
||||
return {"input_ids": seq_token, "loss_mask": loss_mask}
|
||||
|
||||
dataset = dataset.map(process).remove_columns(["question", "answer"])
|
||||
return dataset
|
||||
|
||||
|
||||
def get_gsm8k_dataset(split, tokenizer, rank, world_size):
|
||||
dataset = load_dataset(path="openai/gsm8k", name="main", split=split)
|
||||
dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
|
||||
return process_gsm8k_sft_dataset(dataset, tokenizer)
|
||||
|
||||
|
||||
def main(args):
|
||||
config, _ = load_expr_config(args, SFTConfig)
|
||||
config: SFTConfig
|
||||
|
@ -43,9 +25,26 @@ def main(args):
|
|||
world_size = int(os.getenv("WORLD_SIZE"))
|
||||
tokenizer = load_hf_tokenizer(config.tokenizer_path)
|
||||
|
||||
train_dataset = get_custom_dataset(
|
||||
path=config.train_dataset.path,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
split="train",
|
||||
type=config.train_dataset.type,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
valid_dataset = get_custom_dataset(
|
||||
path=config.valid_dataset.path,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
split="test",
|
||||
type=config.valid_dataset.type,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
# Create dataset and dataloaders
|
||||
train_dataloader = StatefulDataLoader(
|
||||
get_gsm8k_dataset("train", tokenizer, rank, world_size),
|
||||
train_dataset,
|
||||
batch_size=config.train_dataset.batch_size // world_size,
|
||||
shuffle=config.train_dataset.shuffle,
|
||||
num_workers=config.train_dataset.num_workers,
|
||||
|
@ -53,7 +52,7 @@ def main(args):
|
|||
drop_last=config.train_dataset.drop_last,
|
||||
)
|
||||
valid_dataloader = StatefulDataLoader(
|
||||
get_gsm8k_dataset("test", tokenizer, rank, world_size),
|
||||
valid_dataset,
|
||||
batch_size=config.valid_dataset.batch_size // world_size,
|
||||
shuffle=config.valid_dataset.shuffle,
|
||||
num_workers=config.valid_dataset.num_workers,
|
||||
|
|
|
@ -31,7 +31,7 @@ dependencies = [
|
|||
"huggingface_hub",
|
||||
"datasets",
|
||||
"accelerate",
|
||||
"transformers==4.53.0",
|
||||
"transformers==4.53.1",
|
||||
|
||||
# Scientific computing
|
||||
"scipy",
|
||||
|
|
|
@ -69,6 +69,27 @@ def load_hf_tokenizer(
|
|||
return tokenizer
|
||||
|
||||
|
||||
@lru_cache(maxsize=8)
|
||||
def load_hf_processor_and_tokenizer(
|
||||
model_name_or_path: str,
|
||||
fast_tokenizer=True,
|
||||
padding_side: Optional[str] = None,
|
||||
) -> Tuple[transformers.AutoProcessor, transformers.PreTrainedTokenizerFast]:
|
||||
"""Load a tokenizer and processor from Hugging Face."""
|
||||
tokenizer = load_hf_tokenizer(model_name_or_path, fast_tokenizer, padding_side)
|
||||
try:
|
||||
processor = transformers.AutoProcessor.from_pretrained(
|
||||
model_name_or_path, trust_remote_code=True, force_download=True
|
||||
)
|
||||
except Exception:
|
||||
processor = None
|
||||
logger.warning(
|
||||
f"Failed to load processor for {model_name_or_path}. "
|
||||
"Using tokenizer only. This may cause issues with some models."
|
||||
)
|
||||
return processor, tokenizer
|
||||
|
||||
|
||||
@pdclasses.dataclass
|
||||
class SequenceSplitSpec:
|
||||
partitions: Optional[List[Tuple[int, int]]] = None
|
||||
|
|
|
@ -35,6 +35,7 @@ def load_hf_or_local_file(path: str) -> str:
|
|||
=>
|
||||
/root/.cache/huggingface/hub/models--inclusionAI--AReaL-RL-Data/data/boba_106k_0319.jsonl
|
||||
"""
|
||||
path = str(path)
|
||||
if path.startswith("hf://") or path.startswith("hf-dataset://"):
|
||||
repo_type = "dataset" if path.startswith("hf-dataset://") else "model"
|
||||
hf_path = path.strip().split("://")[1]
|
||||
|
|