Merge branch 'lcy/refactor' of https://code.alipay.com/inclusionAI/AReaL into lcy/refactor

lichangye.lcy 2025-07-23 14:44:36 +08:00
commit 391bd85e44
20 changed files with 519 additions and 170 deletions

.clang-format (new file, +95 lines)

@ -0,0 +1,95 @@
---
Language: Cpp
AccessModifierOffset: -1
AlignAfterOpenBracket: Align
AlignConsecutiveAssignments: false
AlignConsecutiveDeclarations: false
AlignEscapedNewlinesLeft: true
AlignOperands: true
AlignTrailingComments: true
AllowAllParametersOfDeclarationOnNextLine: true
AllowShortBlocksOnASingleLine: true
AllowShortCaseLabelsOnASingleLine: true
AllowShortFunctionsOnASingleLine: All
AllowShortIfStatementsOnASingleLine: true
AllowShortLoopsOnASingleLine: true
AlwaysBreakAfterDefinitionReturnType: None
AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: false
AlwaysBreakTemplateDeclarations: true
BinPackArguments: true
BinPackParameters: true
BraceWrapping:
AfterClass: true
AfterControlStatement: false
AfterEnum: false
AfterFunction: false
AfterNamespace: false
AfterObjCDeclaration: false
AfterStruct: false
AfterUnion: false
BeforeCatch: false
BeforeElse: false
IndentBraces: false
BreakBeforeBinaryOperators: NonAssignment
BreakBeforeBraces: Attach
BreakBeforeTernaryOperators: true
BreakConstructorInitializersBeforeComma: false
BreakAfterJavaFieldAnnotations: false
BreakStringLiterals: true
ColumnLimit: 100
CommentPragmas: '^ IWYU pragma:'
BreakBeforeInheritanceComma: false
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
Cpp11BracedListStyle: true
DisableFormat: false
ExperimentalAutoDetectBinPacking: false
FixNamespaceComments: true
ForEachMacros: [ foreach, Q_FOREACH, BOOST_FOREACH ]
IncludeCategories:
- Regex: '^<.*\.h>'
Priority: 1
- Regex: '^<.*'
Priority: 2
- Regex: '.*'
Priority: 3
IncludeIsMainRegex: '([-_](test|unittest))?$'
IndentCaseLabels: true
IndentWidth: 2
IndentWrappedFunctionNames: false
JavaScriptQuotes: Leave
JavaScriptWrapImports: true
KeepEmptyLinesAtTheStartOfBlocks: false
MacroBlockBegin: ''
MacroBlockEnd: ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
ObjCBlockIndentWidth: 2
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: false
PenaltyBreakBeforeFirstCallParameter: 1
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
PenaltyBreakString: 1000
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 200
PointerAlignment: Right
ReflowComments: true
SortIncludes: false
SpaceAfterCStyleCast: false
SpaceAfterTemplateKeyword: false
SpaceBeforeAssignmentOperators: true
SpaceBeforeParens: ControlStatements
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 2
SpacesInAngles: false
SpacesInContainerLiterals: true
SpacesInCStyleCastParentheses: false
SpacesInParentheses: false
SpacesInSquareBrackets: false
Standard: Auto
TabWidth: 8
UseTab: Never
...

.dockerignore (new file, +183 lines)

@ -0,0 +1,183 @@
# Legacy codes
.legacy/
.data/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
trace_result/
profile_result/
slurm_outs
_data
*.nfs*
output
logs
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
# dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# openai api key
api_key.txt
api_key.json
./*.sh
*.png
*.jpg
*.pdf
.vscode/

--- next changed file ---

@ -690,7 +690,7 @@ class ClusterSpecConfig:
@dataclass
class DatasetConfig:
path: str =field(
path: str = field(
default=MISSING,
metadata={
"help": "Path to the dataset. Can be a local path or a HuggingFace dataset name."

--- next changed file ---

@ -48,13 +48,16 @@ class LLMResponse:
def output_len(self) -> int:
return len(self.output_tokens)
@dataclass
class VLMRequest(LLMRequest):
image_data: Optional[List[ImageObject|str]] = field(default_factory=list)
image_data: Optional[List[ImageObject | str]] = field(default_factory=list)
@dataclass
class VLMResponse(LLMResponse):
input_images: List[ImageObject|str] = field(default_factory=list)
input_images: List[ImageObject | str] = field(default_factory=list)
@dataclass
class FinetuneSpec:
@ -142,6 +145,7 @@ class AllocationMode:
raise ValueError(
f"Unknown how to resolve parallelism strategy: {allocation_mode}"
)
@staticmethod
def extract_decoupled_alloc(allocation_mode: str) -> Dict:
pattern = re.compile(

--- next changed file ---

@ -4,36 +4,39 @@ import transformers
VALID_DATASETS = ["gsm8k", "clevr_count_70k"]
def get_custom_dataset(
path: str,
rank: int,
world_size: int,
training_type: str= "sft",
training_type: str = "sft",
split: Optional[str] = None,
tokenizer: Optional[transformers.PreTrainedTokenizerFast] = None,
processor: Optional[transformers.AutoProcessor] = None,
):
):
if "gsm8k" in path and training_type == "sft":
from examples.arealite.dataset.gsm8k import get_gsm8k_sft_dataset
return get_gsm8k_sft_dataset(path, split, tokenizer, rank, world_size)
elif "gsm8k" in path and training_type == "rl":
from examples.arealite.dataset.gsm8k import get_gsm8k_rl_dataset
return get_gsm8k_rl_dataset(path, split, rank, world_size)
elif "clevr_count_70k" in path and training_type == "sft":
from examples.arealite.dataset.clevr_count_70k import (
get_clevr_count_70k_sft_dataset,
)
return get_clevr_count_70k_sft_dataset(path, split, processor, rank, world_size)
elif "clevr_count_70k" in path and training_type == "rl":
from examples.arealite.dataset.clevr_count_70k import (
get_clevr_count_70k_rl_dataset,
)
return get_clevr_count_70k_rl_dataset(path, split,processor, rank, world_size)
return get_clevr_count_70k_rl_dataset(path, split, processor, rank, world_size)
else:
raise ValueError(
f"Dataset {path} with split {split} and training type {training_type} is not supported. "
f"Supported datasets are: {VALID_DATASETS}. "
)
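
For context, the dispatcher above routes by a substring of path and by training_type. A minimal usage sketch follows; the tokenizer name and dataset path are placeholders, not values taken from this commit.

from transformers import AutoTokenizer

from arealite.dataset.__init__ import get_custom_dataset

# Placeholder model id and dataset path, chosen only to exercise the "gsm8k" + "sft" branch.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
train_dataset = get_custom_dataset(
    path="openai/gsm8k",
    rank=0,
    world_size=1,
    training_type="sft",
    split="train",
    tokenizer=tokenizer,
)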

--- next changed file ---

@ -9,8 +9,8 @@ from tensordict import TensorDict
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoProcessor,
AutoModelForImageTextToText,
AutoProcessor,
PretrainedConfig,
PreTrainedTokenizerFast,
get_constant_schedule_with_warmup,
@ -31,8 +31,8 @@ from arealite.utils.data import (
unsqueeze_mb_list,
)
from arealite.utils.fsdp import get_cosine_schedule_with_warmup
from arealite.utils.model import disable_dropout_in_model,VALID_VISION_MODELS
from realhf.api.core.data_api import load_hf_tokenizer, load_hf_processor_and_tokenizer
from arealite.utils.model import VALID_VISION_MODELS, disable_dropout_in_model
from realhf.api.core.data_api import load_hf_processor_and_tokenizer, load_hf_tokenizer
from realhf.base import constants, logging
logger = logging.getLogger("Base HF Engine")
@ -97,7 +97,9 @@ class BaseHFEngine(TrainEngine):
if self.is_vision_model:
dtype = torch.bfloat16
self.processor, self.tokenizer = load_hf_processor_and_tokenizer(self.config.path)
self.processor, self.tokenizer = load_hf_processor_and_tokenizer(
self.config.path
)
tik = time.perf_counter()
with torch.device("cuda"):
@ -231,14 +233,12 @@ class BaseHFEngine(TrainEngine):
assert self.lr_scheduler is not None
self.lr_scheduler.step()
def prepare_mb_list(self, input_: TensorDict) -> MicroBatchList:
assert "attention_mask" in input_ and "input_ids" in input_
if self.is_vision_model:
assert "pixel_values" in input_ and "image_grid_thw" in input_, (
"For vision-language models, pixel_values and image_grid_thw must be present in input_"
)
assert (
"pixel_values" in input_ and "image_grid_thw" in input_
), "For vision-language models, pixel_values and image_grid_thw must be present in input_"
if isinstance(input_, dict):
input_ = TensorDict(input_, batch_size=[input_["input_ids"].shape[0]])
@ -303,7 +303,6 @@ class BaseHFEngine(TrainEngine):
loss *= loss_scale
grad_norm = torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
self.optimizer_config.gradient_clipping,

--- next changed file ---

@ -2,7 +2,7 @@ import os
import time
from datetime import datetime
from typing import Callable, Dict, Optional, Tuple
from realhf.api.core.data_api import load_hf_processor_and_tokenizer
import torch
import torch.distributed as dist
from tensordict import TensorDict
@ -11,7 +11,7 @@ from torch.distributed.checkpoint.state_dict import (
StateDictOptions,
get_model_state_dict,
)
from transformers import PreTrainedTokenizerFast,AutoProcessor
from transformers import AutoProcessor, PreTrainedTokenizerFast
from arealite.api.cli_args import TrainEngineConfig
from arealite.api.engine_api import FinetuneSpec, SaveLoadMeta, WeightUpdateMeta
@ -26,6 +26,7 @@ from arealite.utils.fsdp import (
fsdp2_load_full_state_dict,
)
from arealite.utils.save_load import get_state_dict_from_repo_id_or_path
from realhf.api.core.data_api import load_hf_processor_and_tokenizer
from realhf.base import logging, name_resolve, names, pkg_version
logger = logging.getLogger("FSDPEngine")
@ -49,7 +50,6 @@ class FSDPEngine(BaseHFEngine):
self.create_process_group()
self.create_device_model()
# Wrap with FSDP2
# Simple auto wrap policy
self.mixed_precision_policy = MixedPrecisionPolicy(
@ -100,7 +100,10 @@ class FSDPEngine(BaseHFEngine):
self.load_optimizer_state(meta.path)
def _save_model_to_hf(
self, path: str, tokenizer: Optional[PreTrainedTokenizerFast], processor: Optional[AutoProcessor]
self,
path: str,
tokenizer: Optional[PreTrainedTokenizerFast],
processor: Optional[AutoProcessor],
):
"""Save model in HuggingFace format."""
if self.model is None:
@ -146,7 +149,7 @@ class FSDPEngine(BaseHFEngine):
dist.barrier()
torch.cuda.synchronize()
elif meta.type == "disk":
self._save_model_to_hf(meta.path, self.tokenizer,self.processor)
self._save_model_to_hf(meta.path, self.tokenizer, self.processor)
# dist.barrier() are called when _save_model_to_hf finished
if dist.get_rank() == 0:
update_name = names.update_weights_from_disk(
@ -240,7 +243,6 @@ class FSDPEngine(BaseHFEngine):
loss *= loss_scale
loss.backward()
# NOTE: grad norm clip function is different
grad_norm = fsdp2_clip_grad_norm_(

--- next changed file ---

@ -257,8 +257,6 @@ class FSDPPPOActor(FSDPEngine):
return self.actor.ppo_update(*args, **kwargs)
def grpo_loss_fn(
logits: torch.Tensor,
input_data: Dict,

--- next changed file ---

@ -5,7 +5,6 @@ from tensordict import TensorDict
from arealite.api.cli_args import TrainEngineConfig
from arealite.api.engine_api import TrainEngine
from arealite.engine.fsdp_engine import FSDPEngine
# from arealite.engine.vl_fsdp_engine import VL_FSDPEngine
from arealite.utils.functional import gather_logprobs
from realhf.base import stats_tracker
@ -42,6 +41,7 @@ class FSDPLMEngine(FSDPEngine):
def evaluate_lm(self, data):
return self.lm_engine.evaluate_lm(data)
def compute_packed_sft_loss(logits: torch.Tensor, input_: TensorDict) -> torch.Tensor:
packed_input_ids: torch.Tensor = input_["input_ids"]
cu_seqlens: torch.Tensor = input_["cu_seqlens"]
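
compute_packed_sft_loss above reads packed_input_ids together with cu_seqlens. Assuming cu_seqlens follows the usual cumulative-length convention (not spelled out in this hunk), a small sketch of how the boundaries relate to sequence lengths:

import torch

# Three packed sequences of lengths 4, 2 and 3 concatenated into one 1-D tensor.
lens = torch.tensor([4, 2, 3])
cu_seqlens = torch.cumsum(torch.cat([torch.zeros(1, dtype=torch.long), lens]), dim=0)
# tensor([0, 4, 6, 9]); sequence i occupies packed_input_ids[cu_seqlens[i]:cu_seqlens[i + 1]]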

--- next changed file ---

@ -23,9 +23,9 @@ from arealite.api.io_struct import (
LLMRequest,
LLMResponse,
RolloutStat,
WeightUpdateMeta,
VLMRequest,
VLMResponse
VLMResponse,
WeightUpdateMeta,
)
from arealite.utils.data import concat_padded_tensors
from arealite.utils.http import arequest_with_retry, get_default_connector
@ -228,7 +228,9 @@ class RemoteSGLangEngine(InferenceEngine):
return server
raise NotImplementedError("Only round-robin scheduling is implemented.")
async def agenerate(self, req: LLMRequest|VLMRequest) -> LLMResponse|VLMResponse:
async def agenerate(
self, req: LLMRequest | VLMRequest
) -> LLMResponse | VLMResponse:
"""Async version of generate using aiohttp."""
# Prepare request payload
gconfig = req.gconfig
@ -331,7 +333,7 @@ class RemoteSGLangEngine(InferenceEngine):
ttft=latency, # Simplified for non-streaming
)
else:
response=LLMResponse(
response = LLMResponse(
input_tokens=req.input_ids,
output_tokens=accumulated_output_tokens,
output_logprobs=accumulated_output_logprobs,

--- next changed file ---

@ -65,7 +65,8 @@ def pad_sequences_to_tensors(
return TensorDict()
skip_keys = {"pixel_values", "image_grid_thw"}
max_length = max(
len(seq) for item in sequence_list
len(seq)
for item in sequence_list
for key, seq in item.items()
if key not in skip_keys
)
@ -79,14 +80,18 @@ def pad_sequences_to_tensors(
x = item[key]
if not torch.is_tensor(x):
x = torch.tensor(x)
padded_x=torch.nn.functional.pad(
padded_x = torch.nn.functional.pad(
x, (0, max_length - len(item[key])), value=pad_value
)
padded.append(padded_x)
result[key] = torch.stack(padded)
attention_mask = [
[1] * len(next(iter(item[key] for key in item.keys() if key not in skip_keys)))
+ [0] * (max_length - len(next(iter(item[key] for key in item.keys() if key not in skip_keys))))
+ [0]
* (
max_length
- len(next(iter(item[key] for key in item.keys() if key not in skip_keys)))
)
for item in sequence_list
]
result["attention_mask"] = torch.tensor(attention_mask, dtype=torch.bool)
@ -139,7 +144,7 @@ def concat_padded_tensors(
tensors_to_concat.append(tensor)
continue
current_length = tensor.shape[1]
if key == "pixel_values" or key== "image_grid_thw":
if key == "pixel_values" or key == "image_grid_thw":
tensors_to_concat.append(tensor)
continue
if current_length < max_length:
@ -323,7 +328,7 @@ def split_padded_tensor_dict_into_mb_list(
to_split = {}
not_to_split = {}
for key, value in data.items():
if key=="image_grid_thw" or key=="pixel_values":
if key == "image_grid_thw" or key == "pixel_values":
continue
if not torch.is_tensor(value) or value.numel() != bs * max_seqlen:
not_to_split[key] = value
@ -368,7 +373,7 @@ def split_padded_tensor_dict_into_mb_list(
for group_index in group_indices:
group_pixel_values = [pixel_values[i] for i in group_index]
group_image_grid_thw = [image_grid_thw[i].squeeze()for i in group_index]
group_image_grid_thw = [image_grid_thw[i].squeeze() for i in group_index]
# Stack pixel_values for each group (assuming pixel_values is a list of tensors)
pixel_values_split.append(torch.stack(group_pixel_values))
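
In pad_sequences_to_tensors above, every 1-D field except the skipped vision keys is right-padded to the batch maximum and paired with a boolean attention mask. A standalone sketch of that padding pattern, with made-up values:

import torch
import torch.nn.functional as F

seqs = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
max_length = max(len(s) for s in seqs)
padded = torch.stack([F.pad(s, (0, max_length - len(s)), value=0) for s in seqs])  # shape (2, 3)
attention_mask = torch.tensor(
    [[1] * len(s) + [0] * (max_length - len(s)) for s in seqs], dtype=torch.bool
)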

--- next changed file ---

@ -7,7 +7,7 @@ from typing import List
from PIL.Image import Image as ImageObject
def image2base64(images: List[ImageObject]|ImageObject)-> List[str]|str:
def image2base64(images: List[ImageObject] | ImageObject) -> List[str] | str:
if isinstance(images, ImageObject):
images = [images]
@ -17,10 +17,7 @@ def image2base64(images: List[ImageObject]|ImageObject)-> List[str]|str:
with BytesIO() as buffer:
image.save(buffer, format="PNG")
buffer.seek(0)
byte_image = base64.b64encode(buffer.read()).decode('utf-8')
byte_image = base64.b64encode(buffer.read()).decode("utf-8")
byte_images.append(byte_image)
return byte_images
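
A quick usage sketch of the reformatted helper; the import path is hypothetical, since the file name is not shown in this diff:

from PIL import Image

from arealite.utils.image import image2base64  # hypothetical module path

img = Image.new("RGB", (4, 4))      # tiny in-memory test image
one = image2base64(img)             # a single image yields a single base64 PNG string
many = image2base64([img, img])     # a list of images yields a list of strings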

--- next changed file ---

@ -5,6 +5,7 @@ VALID_VISION_MODELS = [
"qwen2_5_vl",
]
# Copied from trl
def disable_dropout_in_model(model: torch.nn.Module) -> None:
for module in model.modules():

--- next changed file ---

@ -34,7 +34,7 @@ class VL_RLVRWorkflow(RLVRWorkflow):
return_tensors="pt",
)
input_ids=processed_input["input_ids"].tolist()[0]
input_ids = processed_input["input_ids"].tolist()[0]
n_samples = self.gconfig.n_samples
@ -67,8 +67,14 @@ class VL_RLVRWorkflow(RLVRWorkflow):
# unsqueeze to add an additional batch dimension
input_ids=torch.tensor(seq).unsqueeze(0),
loss_mask=torch.tensor(loss_mask).unsqueeze(0),
pixel_values=processed_input["pixel_values"].clone().detach().unsqueeze(0),
image_grid_thw=processed_input["image_grid_thw"].clone().detach().unsqueeze(0),
pixel_values=processed_input["pixel_values"]
.clone()
.detach()
.unsqueeze(0),
image_grid_thw=processed_input["image_grid_thw"]
.clone()
.detach()
.unsqueeze(0),
logprobs=torch.tensor(logprobs).unsqueeze(0),
versions=torch.tensor(versions).unsqueeze(0),
attention_mask=torch.ones(len(seq), dtype=torch.bool).unsqueeze(0),
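
Each per-sample tensor in the workflow output above gets unsqueeze(0) so that later collation can concatenate rows along a batch dimension; a short illustration with made-up token ids:

import torch

seq = [151644, 872, 198]                                              # made-up token ids
input_ids = torch.tensor(seq).unsqueeze(0)                            # shape (1, 3)
attention_mask = torch.ones(len(seq), dtype=torch.bool).unsqueeze(0)  # shape (1, 3)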

--- next changed file ---

@ -1,13 +1,15 @@
import os
import re
import sys
import wandb
import torch
import torch.distributed as dist
import wandb
from torchdata.stateful_dataloader import StatefulDataLoader
from arealite.api.cli_args import GRPOConfig, load_expr_config
from arealite.api.io_struct import FinetuneSpec, WeightUpdateMeta
from arealite.dataset.__init__ import get_custom_dataset
from arealite.engine.ppo.actor import FSDPPPOActor
from arealite.engine.sglang_remote import RemoteSGLangEngine
from arealite.utils.device import log_gpu_stats
@ -17,7 +19,6 @@ from arealite.utils.stats_logger import StatsLogger
from arealite.workflow.vl_rlvr import VL_RLVRWorkflow
from realhf.api.core.data_api import load_hf_processor_and_tokenizer
from realhf.base import stats_tracker
from arealite.dataset.__init__ import get_custom_dataset
def extract_answer(pred_str, data_name, use_last_number=True):
@ -55,11 +56,14 @@ def extract_solution(solution_str, method="strict") -> str | None:
break
return final_answer
def clevr_count_70k_reward_fn(prompt, completions, prompt_ids, completion_ids, answer, **kwargs):
def clevr_count_70k_reward_fn(
prompt, completions, prompt_ids, completion_ids, answer, **kwargs
):
is_thinking = "thinking" in completions.lower()
sol = extract_answer(completions, data_name="") # str number
ans =answer
ans = answer
if sol is None:
return 0
@ -80,7 +84,7 @@ def clevr_count_70k_reward_fn(prompt, completions, prompt_ids, completion_ids, a
def main(args):
os.environ["WANDB_API_KEY"]=""
os.environ["WANDB_API_KEY"] = ""
wandb.init(project="clevr_70k")
config, _ = load_expr_config(args, GRPOConfig)
@ -89,21 +93,21 @@ def main(args):
rank = int(os.getenv("RANK"))
world_size = int(os.getenv("WORLD_SIZE"))
processor, tokenizer = load_hf_processor_and_tokenizer(config.tokenizer_path)
train_dataset=get_custom_dataset(
train_dataset = get_custom_dataset(
path=config.train_dataset.path,
rank=rank,
world_size=world_size,
split="train",
training_type="rl",
processor=processor
processor=processor,
)
valid_dataset=get_custom_dataset(
valid_dataset = get_custom_dataset(
path=config.valid_dataset.path,
rank=rank,
world_size=world_size,
split="test",
training_type="rl",
processor=processor
processor=processor,
)
# Create dataset and dataloaders
train_dataloader = StatefulDataLoader(
@ -141,7 +145,7 @@ def main(args):
actor.initialize(None, ft_spec)
ref = None
if config.actor.kl_ctl > 0 and config.ref is not None:
ref =FSDPPPOActor(config=config.ref)
ref = FSDPPPOActor(config=config.ref)
ref.initialize(None, ft_spec)
# Create rollout workflow
@ -211,7 +215,7 @@ def main(args):
stats_tracker.scope("grpo_actor"),
):
stats = actor.ppo_update(batch)
wandb.log({"actor_reward": stats[0]['grpo_actor/final_reward/avg']})
wandb.log({"actor_reward": stats[0]["grpo_actor/final_reward/avg"]})
actor.step_lr_scheduler()
log_gpu_stats("ppo update")
@ -251,7 +255,14 @@ def main(args):
cnt += 1
batch = eval_rollout.wait(cnt, timeout=None)
rewards = batch["rewards"].float().to(actor.device)
wandb.log({"eval_reward": rewards.mean().item(), "epoch": epoch, "step": step, "global_step": global_step})
wandb.log(
{
"eval_reward": rewards.mean().item(),
"epoch": epoch,
"step": step,
"global_step": global_step,
}
)
with stats_tracker.scope("grpo-eval"):
stats_tracker.denominator(
n_seqs=torch.ones(
@ -280,6 +291,6 @@ def main(args):
actor.destroy()
wandb.finish()
if __name__ == "__main__":
main(sys.argv[1:])

--- next changed file ---

@ -1,9 +1,11 @@
import os
import sys
from torchdata.stateful_dataloader import StatefulDataLoader
from arealite.api.cli_args import SFTConfig, load_expr_config
from arealite.api.io_struct import FinetuneSpec
from arealite.dataset.__init__ import get_custom_dataset
from arealite.engine.sft.lm_engine import FSDPLMEngine
from arealite.utils.data import pad_sequences_to_tensors
from arealite.utils.evaluator import Evaluator
@ -11,7 +13,6 @@ from arealite.utils.saver import Saver
from arealite.utils.stats_logger import StatsLogger
from realhf.api.core.data_api import load_hf_processor_and_tokenizer
from realhf.base import stats_tracker
from arealite.dataset.__init__ import get_custom_dataset
def main_sft():
@ -21,7 +22,7 @@ def main_sft():
rank = int(os.getenv("RANK"))
world_size = int(os.getenv("WORLD_SIZE"))
processor, tokenizer = load_hf_processor_and_tokenizer(config.tokenizer_path)
train_dataset=get_custom_dataset(
train_dataset = get_custom_dataset(
path=config.train_dataset.path,
rank=rank,
world_size=world_size,
@ -30,7 +31,7 @@ def main_sft():
tokenizer=tokenizer,
processor=processor,
)
valid_dataset=get_custom_dataset(
valid_dataset = get_custom_dataset(
path=config.valid_dataset.path,
rank=rank,
world_size=world_size,

--- next changed file ---

@ -1,20 +1,33 @@
from typing import Any, Dict, Optional, Union
import base64
import math
from PIL.Image import Image as ImageObject
import os
from io import BytesIO
from typing import Any, Dict, Optional, Union
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node
import base64
from io import BytesIO
from PIL.Image import Image as ImageObject
def input_text(text:str):
def input_text(text: str):
return {"type": "input_text", "text": text}
def input_image(base64_image: str):
return {"type": "input_image", "image_url": f"data:image/jpeg;base64,{base64_image}"}
def build_raw_message(sample: Dict[str, Any], base64_images: list[str]) -> list[Dict[str, Any]]:
return {
"type": "input_image",
"image_url": f"data:image/jpeg;base64,{base64_image}",
}
def build_raw_message(
sample: Dict[str, Any], base64_images: list[str]
) -> list[Dict[str, Any]]:
raw_message = []
problem_parts = [part.strip() for part in sample["problem"].split("<image>") if part.strip()]
problem_parts = [
part.strip() for part in sample["problem"].split("<image>") if part.strip()
]
insert_list = []
for i, part in enumerate(problem_parts):
if i > 0 or sample["problem"].startswith("<image>"):
@ -38,17 +51,25 @@ def build_raw_message(sample: Dict[str, Any], base64_images: list[str]) -> list[
def encode_image(image_file):
return base64.b64encode(image_file).decode("utf-8")
def convert_image(
image: Union[Dict[str, Any], ImageObject, str], min_pixels: Optional[int], max_pixels: Optional[int]
image: Union[Dict[str, Any], ImageObject, str],
min_pixels: Optional[int],
max_pixels: Optional[int],
) -> ImageObject:
if max_pixels is not None and (image.width * image.height) > max_pixels:
resize_factor = math.sqrt(max_pixels / (image.width * image.height))
width, height = int(image.width * resize_factor), int(image.height * resize_factor)
width, height = int(image.width * resize_factor), int(
image.height * resize_factor
)
image = image.resize((width, height))
if min_pixels is not None and (image.width * image.height) < min_pixels:
resize_factor = math.sqrt(min_pixels / (image.width * image.height))
width, height = int(image.width * resize_factor), int(image.height * resize_factor)
width, height = int(image.width * resize_factor), int(
image.height * resize_factor
)
image = image.resize((width, height))
if image.mode != "RGB":
@ -57,29 +78,31 @@ def convert_image(
image.save(output, format="JPEG")
return output.getvalue()
def get_clevr_count_70k_sft_dataset(path, split, processor, rank, world_size):
'''
"""
"clevr_count_70k": {
"image_key": "images",
"question_key": "problem",
"answer_key": "answer"
},
'''
"""
dataset = load_dataset(path=path, split=split)
dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
tokenizer = processor.tokenizer
def process_example(example, idx):
# Add query_id column
images = example["images"]
if 'qwen' in processor.image_processor.image_processor_type.lower():
image_token="<|vision_start|><|image_pad|><|vision_end|>"
if "qwen" in processor.image_processor.image_processor_type.lower():
image_token = "<|vision_start|><|image_pad|><|vision_end|>"
else:
image_token = processor.image_token if processor is not None else "<image>"
example["problem"] = example["problem"].replace("<image>", image_token)
processed_images = []
for image in images:
processed_images.append(convert_image(image,113*113,336*336))
processed_images.append(convert_image(image, 113 * 113, 336 * 336))
example["images"] = processed_images
example["seq"] = example["problem"] + example["answer"] + tokenizer.eos_token
@ -91,8 +114,8 @@ def get_clevr_count_70k_sft_dataset(path, split, processor, rank, world_size):
)
def _process(example):
text=example["seq"]
processed_input=processor(
text = example["seq"]
processed_input = processor(
text=[text],
images=example["images"],
padding=False,
@ -101,37 +124,51 @@ def get_clevr_count_70k_sft_dataset(path, split, processor, rank, world_size):
return_attention_mask=False,
)
example["input_ids"] =processed_input["input_ids"].squeeze(0)
example["input_ids"] = processed_input["input_ids"].squeeze(0)
example["pixel_values"] = processed_input["pixel_values"]
example["image_grid_thw"] = processed_input["image_grid_thw"]
answer_token = tokenizer.encode(example["answer"])
loss_mask = [0] * (len(example["input_ids"]) - len(answer_token))+[1]*len(answer_token)
example["loss_mask"]=loss_mask
loss_mask = [0] * (len(example["input_ids"]) - len(answer_token)) + [1] * len(
answer_token
)
example["loss_mask"] = loss_mask
return example
dataset = dataset.map(lambda x: _process(x),remove_columns=["images","seq","problem","answer"])
dataset = dataset.map(
lambda x: _process(x), remove_columns=["images", "seq", "problem", "answer"]
)
return dataset
def get_clevr_count_70k_rl_dataset(path, split,processor, rank, world_size):
def get_clevr_count_70k_rl_dataset(path, split, processor, rank, world_size):
dataset = load_dataset(path=path, split=split)
dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
def process(sample):
processed_images = [convert_image(image, 113*113, 336*336) for image in sample["images"]]
if 'qwen' in processor.image_processor.image_processor_type.lower():
image_token="<|vision_start|><|image_pad|><|vision_end|>"
processed_images = [
convert_image(image, 113 * 113, 336 * 336) for image in sample["images"]
]
if "qwen" in processor.image_processor.image_processor_type.lower():
image_token = "<|vision_start|><|image_pad|><|vision_end|>"
else:
image_token = processor.image_token if processor is not None else "<image>"
system_prompt = {
"role": "system",
"content": (
"Solve the following question: count the number of items in the image and provide the final answer in [ ] format, ensuring that only the number is inside the brackets without any additional text or explanations. "
)
),
}
messages =[{"role": "user", "content": sample["problem"].replace("<image>", image_token)}]
messages = [
{
"role": "user",
"content": sample["problem"].replace("<image>", image_token),
}
]
messages.insert(0, system_prompt)
messages=processor.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
messages = processor.tokenizer.apply_chat_template(
messages, add_generation_prompt=True, tokenize=False
)
return {"messages": messages, "images": processed_images}
dataset = dataset.map(process).remove_columns(["problem"])
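
The dataset code above clamps every image into the 113*113 to 336*336 pixel range via convert_image before tokenization. A worked example of the downscaling arithmetic, with a hypothetical input size:

import math

width, height = 640, 480       # hypothetical input size
max_pixels = 336 * 336         # upper bound used in the calls above
if width * height > max_pixels:
    factor = math.sqrt(max_pixels / (width * height))
    width, height = int(width * factor), int(height * factor)
print(width, height)           # 387 290, i.e. an area just below 336*336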

--- next changed file ---

@ -1,9 +1,11 @@
from datasets import load_dataset, Dataset
from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node
def get_gsm8k_sft_dataset(path, split, tokenizer, rank, world_size):
dataset = load_dataset(path=path, name="main", split=split)
dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
def process(sample):
seq_token = tokenizer.encode(
sample["question"] + sample["answer"] + tokenizer.eos_token
@ -15,9 +17,11 @@ def get_gsm8k_sft_dataset(path, split, tokenizer, rank, world_size):
dataset = dataset.map(process).remove_columns(["question", "answer"])
return dataset
def get_gsm8k_rl_dataset(path,split, rank, world_size):
def get_gsm8k_rl_dataset(path, split, rank, world_size):
dataset = load_dataset(path=path, name="main", split=split)
dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
def process(sample):
messages = [{"role": "user", "content": sample["question"]}]
return {"messages": messages}

--- next changed file ---

@ -68,6 +68,7 @@ def load_hf_tokenizer(
tokenizer.pad_token_id = tokenizer.eos_token_id
return tokenizer
@lru_cache(maxsize=8)
def load_hf_processor_and_tokenizer(
model_name_or_path: str,
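
With the new lru_cache(maxsize=8) decorator, repeated loads of the same model path reuse one cached processor/tokenizer pair instead of re-reading from disk. A minimal sketch of the effect; the path is a placeholder and the single-argument call mirrors how the training scripts above invoke it:

from realhf.api.core.data_api import load_hf_processor_and_tokenizer

p1, t1 = load_hf_processor_and_tokenizer("/path/to/model")  # placeholder path; first call loads from disk
p2, t2 = load_hf_processor_and_tokenizer("/path/to/model")  # same argument, served from the cache
assert p1 is p2 and t1 is t2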