Merge branch 'lcy/refactor' of https://code.alipay.com/inclusionAI/AReaL into lcy/refactor

lichangye.lcy 2025-07-23 14:44:36 +08:00
commit 391bd85e44
20 changed files with 519 additions and 170 deletions

.clang-format (new file, 95 lines)

@ -0,0 +1,95 @@
---
Language: Cpp
AccessModifierOffset: -1
AlignAfterOpenBracket: Align
AlignConsecutiveAssignments: false
AlignConsecutiveDeclarations: false
AlignEscapedNewlinesLeft: true
AlignOperands: true
AlignTrailingComments: true
AllowAllParametersOfDeclarationOnNextLine: true
AllowShortBlocksOnASingleLine: true
AllowShortCaseLabelsOnASingleLine: true
AllowShortFunctionsOnASingleLine: All
AllowShortIfStatementsOnASingleLine: true
AllowShortLoopsOnASingleLine: true
AlwaysBreakAfterDefinitionReturnType: None
AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: false
AlwaysBreakTemplateDeclarations: true
BinPackArguments: true
BinPackParameters: true
BraceWrapping:
AfterClass: true
AfterControlStatement: false
AfterEnum: false
AfterFunction: false
AfterNamespace: false
AfterObjCDeclaration: false
AfterStruct: false
AfterUnion: false
BeforeCatch: false
BeforeElse: false
IndentBraces: false
BreakBeforeBinaryOperators: NonAssignment
BreakBeforeBraces: Attach
BreakBeforeTernaryOperators: true
BreakConstructorInitializersBeforeComma: false
BreakAfterJavaFieldAnnotations: false
BreakStringLiterals: true
ColumnLimit: 100
CommentPragmas: '^ IWYU pragma:'
BreakBeforeInheritanceComma: false
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
Cpp11BracedListStyle: true
DisableFormat: false
ExperimentalAutoDetectBinPacking: false
FixNamespaceComments: true
ForEachMacros: [ foreach, Q_FOREACH, BOOST_FOREACH ]
IncludeCategories:
- Regex: '^<.*\.h>'
Priority: 1
- Regex: '^<.*'
Priority: 2
- Regex: '.*'
Priority: 3
IncludeIsMainRegex: '([-_](test|unittest))?$'
IndentCaseLabels: true
IndentWidth: 2
IndentWrappedFunctionNames: false
JavaScriptQuotes: Leave
JavaScriptWrapImports: true
KeepEmptyLinesAtTheStartOfBlocks: false
MacroBlockBegin: ''
MacroBlockEnd: ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
ObjCBlockIndentWidth: 2
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: false
PenaltyBreakBeforeFirstCallParameter: 1
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
PenaltyBreakString: 1000
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 200
PointerAlignment: Right
ReflowComments: true
SortIncludes: false
SpaceAfterCStyleCast: false
SpaceAfterTemplateKeyword: false
SpaceBeforeAssignmentOperators: true
SpaceBeforeParens: ControlStatements
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 2
SpacesInAngles: false
SpacesInContainerLiterals: true
SpacesInCStyleCastParentheses: false
SpacesInParentheses: false
SpacesInSquareBrackets: false
Standard: Auto
TabWidth: 8
UseTab: Never
...
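For reference, a minimal sketch of applying this style from Python; the source directory below is a placeholder, and clang-format is assumed to be on PATH:

import subprocess
from pathlib import Path

# Format C++ sources in place; --style=file makes clang-format pick up
# the nearest .clang-format (the one added in this commit).
sources = [str(p) for p in Path("csrc").rglob("*.cc")]  # "csrc" is a placeholder directory
if sources:
    subprocess.run(["clang-format", "-i", "--style=file", *sources], check=True)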

.dockerignore (new file, 183 lines)

@ -0,0 +1,183 @@
# Legacy codes
.legacy/
.data/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
trace_result/
profile_result/
slurm_outs
_data
*.nfs*
output
logs
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
# dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# openai api key
api_key.txt
api_key.json
./*.sh
*.png
*.jpg
*.pdf
.vscode/

View File

@ -690,7 +690,7 @@ class ClusterSpecConfig:
@dataclass
class DatasetConfig:
path: str =field(
path: str = field(
default=MISSING,
metadata={
"help": "Path to the dataset. Can be a local path or a HuggingFace dataset name."

View File

@ -47,14 +47,17 @@ class LLMResponse:
@property
def output_len(self) -> int:
return len(self.output_tokens)
@dataclass
class VLMRequest(LLMRequest):
image_data: Optional[List[ImageObject|str]] = field(default_factory=list)
image_data: Optional[List[ImageObject | str]] = field(default_factory=list)
@dataclass
class VLMResponse(LLMResponse):
input_images: List[ImageObject|str] = field(default_factory=list)
input_images: List[ImageObject | str] = field(default_factory=list)
@dataclass
class FinetuneSpec:
@ -142,6 +145,7 @@ class AllocationMode:
raise ValueError(
f"Unknown how to resolve parallelism strategy: {allocation_mode}"
)
@staticmethod
def extract_decoupled_alloc(allocation_mode: str) -> Dict:
pattern = re.compile(

View File

@ -4,36 +4,39 @@ import transformers
VALID_DATASETS = ["gsm8k", "clevr_count_70k"]
def get_custom_dataset(
path: str,
rank: int,
world_size: int,
training_type: str= "sft",
training_type: str = "sft",
split: Optional[str] = None,
tokenizer: Optional[transformers.PreTrainedTokenizerFast] = None,
processor: Optional[transformers.AutoProcessor] = None,
):
):
if "gsm8k" in path and training_type == "sft":
from examples.arealite.dataset.gsm8k import get_gsm8k_sft_dataset
return get_gsm8k_sft_dataset(path, split, tokenizer, rank, world_size)
elif "gsm8k" in path and training_type == "rl":
from examples.arealite.dataset.gsm8k import get_gsm8k_rl_dataset
return get_gsm8k_rl_dataset(path, split, rank, world_size)
elif "clevr_count_70k" in path and training_type == "sft":
from examples.arealite.dataset.clevr_count_70k import (
get_clevr_count_70k_sft_dataset,
)
return get_clevr_count_70k_sft_dataset(path, split, processor, rank, world_size)
elif "clevr_count_70k" in path and training_type == "rl":
from examples.arealite.dataset.clevr_count_70k import (
get_clevr_count_70k_rl_dataset,
)
return get_clevr_count_70k_rl_dataset(path, split,processor, rank, world_size)
return get_clevr_count_70k_rl_dataset(path, split, processor, rank, world_size)
else:
raise ValueError(
f"Dataset {path} with split {split} and training type {training_type} is not supported. "
f"Supported datasets are: {VALID_DATASETS}. "
)
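As a hedged usage sketch of the dispatcher above (the model name and dataset path are placeholders, not values taken from this commit; the training scripts import it via arealite.dataset.__init__):

from transformers import AutoTokenizer

from arealite.dataset import get_custom_dataset

# Single-process illustration; in real runs rank/world_size come from the launcher.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")  # placeholder model
train_dataset = get_custom_dataset(
    path="openai/gsm8k",      # placeholder; must match one of VALID_DATASETS
    rank=0,
    world_size=1,
    training_type="sft",
    split="train",
    tokenizer=tokenizer,
)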

View File

@ -9,8 +9,8 @@ from tensordict import TensorDict
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoProcessor,
AutoModelForImageTextToText,
AutoProcessor,
PretrainedConfig,
PreTrainedTokenizerFast,
get_constant_schedule_with_warmup,
@ -31,8 +31,8 @@ from arealite.utils.data import (
unsqueeze_mb_list,
)
from arealite.utils.fsdp import get_cosine_schedule_with_warmup
from arealite.utils.model import disable_dropout_in_model,VALID_VISION_MODELS
from realhf.api.core.data_api import load_hf_tokenizer, load_hf_processor_and_tokenizer
from arealite.utils.model import VALID_VISION_MODELS, disable_dropout_in_model
from realhf.api.core.data_api import load_hf_processor_and_tokenizer, load_hf_tokenizer
from realhf.base import constants, logging
logger = logging.getLogger("Base HF Engine")
@ -55,7 +55,7 @@ class BaseHFEngine(TrainEngine):
self.own_global_group = False
self._parallelism_group: dist.ProcessGroup
self.weight_update_group_initialized = False
self.model_config = AutoConfig.from_pretrained(
pretrained_model_name_or_path=self.config.path,
trust_remote_code=True,
@ -96,8 +96,10 @@ class BaseHFEngine(TrainEngine):
dtype = getattr(torch, self.config.dtype)
if self.is_vision_model:
dtype = torch.bfloat16
self.processor, self.tokenizer = load_hf_processor_and_tokenizer(self.config.path)
dtype = torch.bfloat16
self.processor, self.tokenizer = load_hf_processor_and_tokenizer(
self.config.path
)
tik = time.perf_counter()
with torch.device("cuda"):
@ -132,7 +134,7 @@ class BaseHFEngine(TrainEngine):
)
if self.config.disable_dropout:
disable_dropout_in_model(model)
if self.config.gradient_checkpointing:
model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False}
@ -231,14 +233,12 @@ class BaseHFEngine(TrainEngine):
assert self.lr_scheduler is not None
self.lr_scheduler.step()
def prepare_mb_list(self, input_: TensorDict) -> MicroBatchList:
assert "attention_mask" in input_ and "input_ids" in input_
if self.is_vision_model:
assert "pixel_values" in input_ and "image_grid_thw" in input_, (
"For vision-language models, pixel_values and image_grid_thw must be present in input_"
)
assert (
"pixel_values" in input_ and "image_grid_thw" in input_
), "For vision-language models, pixel_values and image_grid_thw must be present in input_"
if isinstance(input_, dict):
input_ = TensorDict(input_, batch_size=[input_["input_ids"].shape[0]])
@ -303,7 +303,6 @@ class BaseHFEngine(TrainEngine):
loss *= loss_scale
grad_norm = torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
self.optimizer_config.gradient_clipping,

View File

@ -2,7 +2,7 @@ import os
import time
from datetime import datetime
from typing import Callable, Dict, Optional, Tuple
from realhf.api.core.data_api import load_hf_processor_and_tokenizer
import torch
import torch.distributed as dist
from tensordict import TensorDict
@ -11,7 +11,7 @@ from torch.distributed.checkpoint.state_dict import (
StateDictOptions,
get_model_state_dict,
)
from transformers import PreTrainedTokenizerFast,AutoProcessor
from transformers import AutoProcessor, PreTrainedTokenizerFast
from arealite.api.cli_args import TrainEngineConfig
from arealite.api.engine_api import FinetuneSpec, SaveLoadMeta, WeightUpdateMeta
@ -26,6 +26,7 @@ from arealite.utils.fsdp import (
fsdp2_load_full_state_dict,
)
from arealite.utils.save_load import get_state_dict_from_repo_id_or_path
from realhf.api.core.data_api import load_hf_processor_and_tokenizer
from realhf.base import logging, name_resolve, names, pkg_version
logger = logging.getLogger("FSDPEngine")
@ -48,7 +49,6 @@ class FSDPEngine(BaseHFEngine):
self.create_process_group()
self.create_device_model()
# Wrap with FSDP2
# Simple auto wrap policy
@ -100,7 +100,10 @@ class FSDPEngine(BaseHFEngine):
self.load_optimizer_state(meta.path)
def _save_model_to_hf(
self, path: str, tokenizer: Optional[PreTrainedTokenizerFast], processor: Optional[AutoProcessor]
self,
path: str,
tokenizer: Optional[PreTrainedTokenizerFast],
processor: Optional[AutoProcessor],
):
"""Save model in HuggingFace format."""
if self.model is None:
@ -146,7 +149,7 @@ class FSDPEngine(BaseHFEngine):
dist.barrier()
torch.cuda.synchronize()
elif meta.type == "disk":
self._save_model_to_hf(meta.path, self.tokenizer,self.processor)
self._save_model_to_hf(meta.path, self.tokenizer, self.processor)
# dist.barrier() are called when _save_model_to_hf finished
if dist.get_rank() == 0:
update_name = names.update_weights_from_disk(
@ -239,7 +242,6 @@ class FSDPEngine(BaseHFEngine):
loss *= loss_scale
loss.backward()
# NOTE: grad norm clip function is different

View File

@ -257,8 +257,6 @@ class FSDPPPOActor(FSDPEngine):
return self.actor.ppo_update(*args, **kwargs)
def grpo_loss_fn(
logits: torch.Tensor,
input_data: Dict,

View File

@ -5,7 +5,6 @@ from tensordict import TensorDict
from arealite.api.cli_args import TrainEngineConfig
from arealite.api.engine_api import TrainEngine
from arealite.engine.fsdp_engine import FSDPEngine
# from arealite.engine.vl_fsdp_engine import VL_FSDPEngine
from arealite.utils.functional import gather_logprobs
from realhf.base import stats_tracker
@ -42,6 +41,7 @@ class FSDPLMEngine(FSDPEngine):
def evaluate_lm(self, data):
return self.lm_engine.evaluate_lm(data)
def compute_packed_sft_loss(logits: torch.Tensor, input_: TensorDict) -> torch.Tensor:
packed_input_ids: torch.Tensor = input_["input_ids"]
cu_seqlens: torch.Tensor = input_["cu_seqlens"]
@ -50,7 +50,7 @@ def compute_packed_sft_loss(logits: torch.Tensor, input_: TensorDict) -> torch.T
logprobs = gather_logprobs(logits, torch.roll(packed_input_ids, shifts=-1, dims=-1))
loss_mask = torch.roll(loss_mask, shifts=-1, dims=-1)
logprobs = torch.where(loss_mask, logprobs, 0)
loss = -logprobs.sum() / loss_mask.count_nonzero()
with torch.no_grad():
seqlogp = torch.zeros(

View File

@ -23,9 +23,9 @@ from arealite.api.io_struct import (
LLMRequest,
LLMResponse,
RolloutStat,
VLMRequest,
VLMResponse,
WeightUpdateMeta,
VLMRequest,
VLMResponse
)
from arealite.utils.data import concat_padded_tensors
from arealite.utils.http import arequest_with_retry, get_default_connector
@ -228,7 +228,9 @@ class RemoteSGLangEngine(InferenceEngine):
return server
raise NotImplementedError("Only round-robin scheduling is implemented.")
async def agenerate(self, req: LLMRequest|VLMRequest) -> LLMResponse|VLMResponse:
async def agenerate(
self, req: LLMRequest | VLMRequest
) -> LLMResponse | VLMResponse:
"""Async version of generate using aiohttp."""
# Prepare request payload
gconfig = req.gconfig
@ -318,28 +320,28 @@ class RemoteSGLangEngine(InferenceEngine):
sample_params["max_new_tokens"] -= len(output_tokens)
latency = time.perf_counter() - start_time
if isinstance(req, VLMRequest):
response = VLMResponse(
input_tokens=req.input_ids,
input_images=req.image_data,
output_tokens=accumulated_output_tokens,
output_logprobs=accumulated_output_logprobs,
output_versions=accumulated_versions,
stop_reason=stop_reason,
latency=latency,
ttft=latency, # Simplified for non-streaming
input_tokens=req.input_ids,
input_images=req.image_data,
output_tokens=accumulated_output_tokens,
output_logprobs=accumulated_output_logprobs,
output_versions=accumulated_versions,
stop_reason=stop_reason,
latency=latency,
ttft=latency, # Simplified for non-streaming
)
else:
response=LLMResponse(
input_tokens=req.input_ids,
output_tokens=accumulated_output_tokens,
output_logprobs=accumulated_output_logprobs,
output_versions=accumulated_versions,
stop_reason=stop_reason,
latency=latency,
ttft=latency, # Simplified for non-streaming
)
response = LLMResponse(
input_tokens=req.input_ids,
output_tokens=accumulated_output_tokens,
output_logprobs=accumulated_output_logprobs,
output_versions=accumulated_versions,
stop_reason=stop_reason,
latency=latency,
ttft=latency, # Simplified for non-streaming
)
return response
def update_weights(self, meta):
@ -526,7 +528,7 @@ class RemoteSGLangEngine(InferenceEngine):
):
try:
data = next(self.data_generator)
except StopIteration:
self.data_generator = iter(dataloader)
data = next(self.data_generator)

View File

@ -65,8 +65,9 @@ def pad_sequences_to_tensors(
return TensorDict()
skip_keys = {"pixel_values", "image_grid_thw"}
max_length = max(
len(seq) for item in sequence_list
for key, seq in item.items()
len(seq)
for item in sequence_list
for key, seq in item.items()
if key not in skip_keys
)
result = {}
@ -79,14 +80,18 @@ def pad_sequences_to_tensors(
x = item[key]
if not torch.is_tensor(x):
x = torch.tensor(x)
padded_x=torch.nn.functional.pad(
x, (0, max_length - len(item[key])), value=pad_value
)
padded_x = torch.nn.functional.pad(
x, (0, max_length - len(item[key])), value=pad_value
)
padded.append(padded_x)
result[key] = torch.stack(padded)
attention_mask = [
[1] * len(next(iter(item[key] for key in item.keys() if key not in skip_keys)))
+ [0] * (max_length - len(next(iter(item[key] for key in item.keys() if key not in skip_keys))))
+ [0]
* (
max_length
- len(next(iter(item[key] for key in item.keys() if key not in skip_keys)))
)
for item in sequence_list
]
result["attention_mask"] = torch.tensor(attention_mask, dtype=torch.bool)
@ -139,7 +144,7 @@ def concat_padded_tensors(
tensors_to_concat.append(tensor)
continue
current_length = tensor.shape[1]
if key == "pixel_values" or key== "image_grid_thw":
if key == "pixel_values" or key == "image_grid_thw":
tensors_to_concat.append(tensor)
continue
if current_length < max_length:
@ -150,7 +155,7 @@ def concat_padded_tensors(
padding = torch.zeros(
(tensor.shape[0], pad_width), dtype=tensor.dtype
)
else:
# Pad feature tensors with pad_value
padding = torch.full(
@ -323,7 +328,7 @@ def split_padded_tensor_dict_into_mb_list(
to_split = {}
not_to_split = {}
for key, value in data.items():
if key=="image_grid_thw" or key=="pixel_values":
if key == "image_grid_thw" or key == "pixel_values":
continue
if not torch.is_tensor(value) or value.numel() != bs * max_seqlen:
not_to_split[key] = value
@ -368,7 +373,7 @@ def split_padded_tensor_dict_into_mb_list(
for group_index in group_indices:
group_pixel_values = [pixel_values[i] for i in group_index]
group_image_grid_thw = [image_grid_thw[i].squeeze()for i in group_index]
group_image_grid_thw = [image_grid_thw[i].squeeze() for i in group_index]
# Stack pixel_values for each group (assuming pixel_values is a list of tensors)
pixel_values_split.append(torch.stack(group_pixel_values))

View File

@ -46,7 +46,7 @@ def fsdp2_clip_grad_norm_(
grads = [p.grad for p in parameters if p.grad is not None]
total_norm = _get_total_norm(grads, norm_type, error_if_nonfinite, foreach)
total_norm = total_norm.to(torch.cuda.current_device(), non_blocking=True)
_clip_grads_with_norm_(parameters, max_norm, total_norm, foreach)
return total_norm

View File

@ -7,20 +7,17 @@ from typing import List
from PIL.Image import Image as ImageObject
def image2base64(images: List[ImageObject]|ImageObject)-> List[str]|str:
def image2base64(images: List[ImageObject] | ImageObject) -> List[str] | str:
if isinstance(images, ImageObject):
images = [images]
byte_images = []
for image in images:
with BytesIO() as buffer:
image.save(buffer, format="PNG")
buffer.seek(0)
byte_image = base64.b64encode(buffer.read()).decode('utf-8')
buffer.seek(0)
byte_image = base64.b64encode(buffer.read()).decode("utf-8")
byte_images.append(byte_image)
return byte_images
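A small usage sketch of the helper above; note that a single image input still yields a one-element list, and the import path is assumed since the file name is not shown in this view:

from PIL import Image

from arealite.utils.image import image2base64  # path assumed, not shown above

# A solid-color 64x64 test image; real callers pass dataset images.
img = Image.new("RGB", (64, 64), color=(255, 0, 0))
encoded = image2base64(img)  # returns a list of base64 strings, even for one image
print(len(encoded), encoded[0][:16])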

View File

@ -5,6 +5,7 @@ VALID_VISION_MODELS = [
"qwen2_5_vl",
]
# Copied from trl
def disable_dropout_in_model(model: torch.nn.Module) -> None:
for module in model.modules():

View File

@ -34,7 +34,7 @@ class VL_RLVRWorkflow(RLVRWorkflow):
return_tensors="pt",
)
input_ids=processed_input["input_ids"].tolist()[0]
input_ids = processed_input["input_ids"].tolist()[0]
n_samples = self.gconfig.n_samples
@ -62,13 +62,19 @@ class VL_RLVRWorkflow(RLVRWorkflow):
completion_ids=resp.output_tokens,
**data,
)
res = dict(
# unsqueeze to add an additional batch dimension
input_ids=torch.tensor(seq).unsqueeze(0),
loss_mask=torch.tensor(loss_mask).unsqueeze(0),
pixel_values=processed_input["pixel_values"].clone().detach().unsqueeze(0),
image_grid_thw=processed_input["image_grid_thw"].clone().detach().unsqueeze(0),
pixel_values=processed_input["pixel_values"]
.clone()
.detach()
.unsqueeze(0),
image_grid_thw=processed_input["image_grid_thw"]
.clone()
.detach()
.unsqueeze(0),
logprobs=torch.tensor(logprobs).unsqueeze(0),
versions=torch.tensor(versions).unsqueeze(0),
attention_mask=torch.ones(len(seq), dtype=torch.bool).unsqueeze(0),

View File

@ -1,13 +1,15 @@
import os
import re
import sys
import wandb
import torch
import torch.distributed as dist
import wandb
from torchdata.stateful_dataloader import StatefulDataLoader
from arealite.api.cli_args import GRPOConfig, load_expr_config
from arealite.api.io_struct import FinetuneSpec, WeightUpdateMeta
from arealite.dataset.__init__ import get_custom_dataset
from arealite.engine.ppo.actor import FSDPPPOActor
from arealite.engine.sglang_remote import RemoteSGLangEngine
from arealite.utils.device import log_gpu_stats
@ -17,13 +19,12 @@ from arealite.utils.stats_logger import StatsLogger
from arealite.workflow.vl_rlvr import VL_RLVRWorkflow
from realhf.api.core.data_api import load_hf_processor_and_tokenizer
from realhf.base import stats_tracker
from arealite.dataset.__init__ import get_custom_dataset
def extract_answer(pred_str, data_name, use_last_number=True):
match = re.findall(r"\[([0-9\.]+)\]", pred_str)
if match:
return match[-1]
return match[-1]
return ""
@ -55,32 +56,35 @@ def extract_solution(solution_str, method="strict") -> str | None:
break
return final_answer
def clevr_count_70k_reward_fn(prompt, completions, prompt_ids, completion_ids, answer, **kwargs):
def clevr_count_70k_reward_fn(
prompt, completions, prompt_ids, completion_ids, answer, **kwargs
):
is_thinking = "thinking" in completions.lower()
sol = extract_answer(completions, data_name="") # str number
ans =answer
sol = extract_answer(completions, data_name="") # str number
ans = answer
if sol is None:
return 0
if ans is None:
return 0
if sol.strip() == ans.strip():
print(f"completions: {completions}, answer: {answer}")
if is_thinking:
return 1
return 1
else:
return 1
if re.match(r"^\[\d+(\.\d+)?\]$", sol.strip()):
return 0.05
return 0
def main(args):
os.environ["WANDB_API_KEY"]=""
os.environ["WANDB_API_KEY"] = ""
wandb.init(project="clevr_70k")
config, _ = load_expr_config(args, GRPOConfig)
@ -89,22 +93,22 @@ def main(args):
rank = int(os.getenv("RANK"))
world_size = int(os.getenv("WORLD_SIZE"))
processor, tokenizer = load_hf_processor_and_tokenizer(config.tokenizer_path)
train_dataset=get_custom_dataset(
path=config.train_dataset.path,
rank=rank,
world_size=world_size,
split="train",
training_type="rl",
processor=processor
)
valid_dataset=get_custom_dataset(
path=config.valid_dataset.path,
rank=rank,
world_size=world_size,
split="test",
training_type="rl",
processor=processor
)
train_dataset = get_custom_dataset(
path=config.train_dataset.path,
rank=rank,
world_size=world_size,
split="train",
training_type="rl",
processor=processor,
)
valid_dataset = get_custom_dataset(
path=config.valid_dataset.path,
rank=rank,
world_size=world_size,
split="test",
training_type="rl",
processor=processor,
)
# Create dataset and dataloaders
train_dataloader = StatefulDataLoader(
train_dataset,
@ -141,7 +145,7 @@ def main(args):
actor.initialize(None, ft_spec)
ref = None
if config.actor.kl_ctl > 0 and config.ref is not None:
ref =FSDPPPOActor(config=config.ref)
ref = FSDPPPOActor(config=config.ref)
ref.initialize(None, ft_spec)
# Create rollout workflow
@ -199,7 +203,7 @@ def main(args):
if ref is not None:
with stats_tracker.record_timing("ref_logp"):
batch["ref_logp"] = ref.compute_logp(batch)
log_gpu_stats("ref logp")
with stats_tracker.record_timing("compute_advantage"):
@ -211,8 +215,8 @@ def main(args):
stats_tracker.scope("grpo_actor"),
):
stats = actor.ppo_update(batch)
wandb.log({"actor_reward": stats[0]['grpo_actor/final_reward/avg']})
wandb.log({"actor_reward": stats[0]["grpo_actor/final_reward/avg"]})
actor.step_lr_scheduler()
log_gpu_stats("ppo update")
@ -251,7 +255,14 @@ def main(args):
cnt += 1
batch = eval_rollout.wait(cnt, timeout=None)
rewards = batch["rewards"].float().to(actor.device)
wandb.log({"eval_reward": rewards.mean().item(), "epoch": epoch, "step": step, "global_step": global_step})
wandb.log(
{
"eval_reward": rewards.mean().item(),
"epoch": epoch,
"step": step,
"global_step": global_step,
}
)
with stats_tracker.scope("grpo-eval"):
stats_tracker.denominator(
n_seqs=torch.ones(
@ -280,6 +291,6 @@ def main(args):
actor.destroy()
wandb.finish()
if __name__ == "__main__":
main(sys.argv[1:])

View File

@ -1,9 +1,11 @@
import os
import sys
from torchdata.stateful_dataloader import StatefulDataLoader
from arealite.api.cli_args import SFTConfig, load_expr_config
from arealite.api.io_struct import FinetuneSpec
from arealite.dataset.__init__ import get_custom_dataset
from arealite.engine.sft.lm_engine import FSDPLMEngine
from arealite.utils.data import pad_sequences_to_tensors
from arealite.utils.evaluator import Evaluator
@ -11,7 +13,6 @@ from arealite.utils.saver import Saver
from arealite.utils.stats_logger import StatsLogger
from realhf.api.core.data_api import load_hf_processor_and_tokenizer
from realhf.base import stats_tracker
from arealite.dataset.__init__ import get_custom_dataset
def main_sft():
@ -21,25 +22,25 @@ def main_sft():
rank = int(os.getenv("RANK"))
world_size = int(os.getenv("WORLD_SIZE"))
processor, tokenizer = load_hf_processor_and_tokenizer(config.tokenizer_path)
train_dataset=get_custom_dataset(
path=config.train_dataset.path,
rank=rank,
world_size=world_size,
split="train",
training_type="sft",
tokenizer=tokenizer,
processor=processor,
)
valid_dataset=get_custom_dataset(
path=config.valid_dataset.path,
rank=rank,
world_size=world_size,
split="test",
training_type="sft",
tokenizer=tokenizer,
processor=processor,
)
train_dataset = get_custom_dataset(
path=config.train_dataset.path,
rank=rank,
world_size=world_size,
split="train",
training_type="sft",
tokenizer=tokenizer,
processor=processor,
)
valid_dataset = get_custom_dataset(
path=config.valid_dataset.path,
rank=rank,
world_size=world_size,
split="test",
training_type="sft",
tokenizer=tokenizer,
processor=processor,
)
# Create dataset and dataloaders
train_dataloader = StatefulDataLoader(
train_dataset,

View File

@ -1,26 +1,39 @@
from typing import Any, Dict, Optional, Union
import base64
import math
from PIL.Image import Image as ImageObject
import os
from io import BytesIO
from typing import Any, Dict, Optional, Union
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node
import base64
from io import BytesIO
from PIL.Image import Image as ImageObject
def input_text(text:str):
def input_text(text: str):
return {"type": "input_text", "text": text}
def input_image(base64_image: str):
return {"type": "input_image", "image_url": f"data:image/jpeg;base64,{base64_image}"}
def build_raw_message(sample: Dict[str, Any], base64_images: list[str]) -> list[Dict[str, Any]]:
return {
"type": "input_image",
"image_url": f"data:image/jpeg;base64,{base64_image}",
}
def build_raw_message(
sample: Dict[str, Any], base64_images: list[str]
) -> list[Dict[str, Any]]:
raw_message = []
problem_parts = [part.strip() for part in sample["problem"].split("<image>") if part.strip()]
problem_parts = [
part.strip() for part in sample["problem"].split("<image>") if part.strip()
]
insert_list = []
for i, part in enumerate(problem_parts):
if i > 0 or sample["problem"].startswith("<image>"):
if i > 0 or sample["problem"].startswith("<image>"):
insert_list.append("image")
part = part.strip()
if part:
part = part.strip()
if part:
insert_list.append("text")
image_index = 0
text_index = 0
@ -38,17 +51,25 @@ def build_raw_message(sample: Dict[str, Any], base64_images: list[str]) -> list[
def encode_image(image_file):
return base64.b64encode(image_file).decode("utf-8")
def convert_image(
image: Union[Dict[str, Any], ImageObject, str], min_pixels: Optional[int], max_pixels: Optional[int]
image: Union[Dict[str, Any], ImageObject, str],
min_pixels: Optional[int],
max_pixels: Optional[int],
) -> ImageObject:
if max_pixels is not None and (image.width * image.height) > max_pixels:
resize_factor = math.sqrt(max_pixels / (image.width * image.height))
width, height = int(image.width * resize_factor), int(image.height * resize_factor)
width, height = int(image.width * resize_factor), int(
image.height * resize_factor
)
image = image.resize((width, height))
if min_pixels is not None and (image.width * image.height) < min_pixels:
resize_factor = math.sqrt(min_pixels / (image.width * image.height))
width, height = int(image.width * resize_factor), int(image.height * resize_factor)
width, height = int(image.width * resize_factor), int(
image.height * resize_factor
)
image = image.resize((width, height))
if image.mode != "RGB":
@ -57,32 +78,34 @@ def convert_image(
image.save(output, format="JPEG")
return output.getvalue()
def get_clevr_count_70k_sft_dataset(path, split, processor, rank, world_size):
'''
"""
"clevr_count_70k": {
"image_key": "images",
"question_key": "problem",
"answer_key": "answer"
},
'''
"""
dataset = load_dataset(path=path, split=split)
dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
tokenizer = processor.tokenizer
tokenizer = processor.tokenizer
def process_example(example, idx):
# Add query_id column
images = example["images"]
if 'qwen' in processor.image_processor.image_processor_type.lower():
image_token="<|vision_start|><|image_pad|><|vision_end|>"
if "qwen" in processor.image_processor.image_processor_type.lower():
image_token = "<|vision_start|><|image_pad|><|vision_end|>"
else:
image_token = processor.image_token if processor is not None else "<image>"
example["problem"] = example["problem"].replace("<image>", image_token)
processed_images = []
for image in images:
processed_images.append(convert_image(image,113*113,336*336))
processed_images.append(convert_image(image, 113 * 113, 336 * 336))
example["images"] = processed_images
example["seq"] = example["problem"] + example["answer"] + tokenizer.eos_token
return example
dataset = dataset.map(
@ -91,8 +114,8 @@ def get_clevr_count_70k_sft_dataset(path, split, processor, rank, world_size):
)
def _process(example):
text=example["seq"]
processed_input=processor(
text = example["seq"]
processed_input = processor(
text=[text],
images=example["images"],
padding=False,
@ -101,38 +124,52 @@ def get_clevr_count_70k_sft_dataset(path, split, processor, rank, world_size):
return_attention_mask=False,
)
example["input_ids"] =processed_input["input_ids"].squeeze(0)
example["input_ids"] = processed_input["input_ids"].squeeze(0)
example["pixel_values"] = processed_input["pixel_values"]
example["image_grid_thw"] = processed_input["image_grid_thw"]
answer_token = tokenizer.encode(example["answer"])
loss_mask = [0] * (len(example["input_ids"]) - len(answer_token))+[1]*len(answer_token)
example["loss_mask"]=loss_mask
loss_mask = [0] * (len(example["input_ids"]) - len(answer_token)) + [1] * len(
answer_token
)
example["loss_mask"] = loss_mask
return example
dataset = dataset.map(lambda x: _process(x),remove_columns=["images","seq","problem","answer"])
dataset = dataset.map(
lambda x: _process(x), remove_columns=["images", "seq", "problem", "answer"]
)
return dataset
def get_clevr_count_70k_rl_dataset(path, split,processor, rank, world_size):
def get_clevr_count_70k_rl_dataset(path, split, processor, rank, world_size):
dataset = load_dataset(path=path, split=split)
dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
def process(sample):
processed_images = [convert_image(image, 113*113, 336*336) for image in sample["images"]]
if 'qwen' in processor.image_processor.image_processor_type.lower():
image_token="<|vision_start|><|image_pad|><|vision_end|>"
processed_images = [
convert_image(image, 113 * 113, 336 * 336) for image in sample["images"]
]
if "qwen" in processor.image_processor.image_processor_type.lower():
image_token = "<|vision_start|><|image_pad|><|vision_end|>"
else:
image_token = processor.image_token if processor is not None else "<image>"
system_prompt = {
"role": "system",
"role": "system",
"content": (
"Solve the following question: count the number of items in the image and provide the final answer in [ ] format, ensuring that only the number is inside the brackets without any additional text or explanations. "
)
),
}
messages =[{"role": "user", "content": sample["problem"].replace("<image>", image_token)}]
messages = [
{
"role": "user",
"content": sample["problem"].replace("<image>", image_token),
}
]
messages.insert(0, system_prompt)
messages=processor.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
messages = processor.tokenizer.apply_chat_template(
messages, add_generation_prompt=True, tokenize=False
)
return {"messages": messages, "images": processed_images}
dataset = dataset.map(process).remove_columns(["problem"])
return dataset
return dataset

View File

@ -1,9 +1,11 @@
from datasets import load_dataset, Dataset
from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node
def get_gsm8k_sft_dataset(path, split, tokenizer, rank, world_size):
dataset = load_dataset(path=path, name="main", split=split)
dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
def process(sample):
seq_token = tokenizer.encode(
sample["question"] + sample["answer"] + tokenizer.eos_token
@ -15,12 +17,14 @@ def get_gsm8k_sft_dataset(path, split, tokenizer, rank, world_size):
dataset = dataset.map(process).remove_columns(["question", "answer"])
return dataset
def get_gsm8k_rl_dataset(path,split, rank, world_size):
def get_gsm8k_rl_dataset(path, split, rank, world_size):
dataset = load_dataset(path=path, name="main", split=split)
dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
def process(sample):
messages = [{"role": "user", "content": sample["question"]}]
return {"messages": messages}
dataset = dataset.map(process).remove_columns(["question"])
return dataset
return dataset

View File

@ -68,6 +68,7 @@ def load_hf_tokenizer(
tokenizer.pad_token_id = tokenizer.eos_token_id
return tokenizer
@lru_cache(maxsize=8)
def load_hf_processor_and_tokenizer(
model_name_or_path: str,