Merge branch 'lcy/refactor' of https://code.alipay.com/inclusionAI/AReaL into lcy/refactor

lichangye.lcy 2025-07-23 14:44:36 +08:00
commit 391bd85e44
20 changed files with 519 additions and 170 deletions

.clang-format (new file, 95 lines)

@@ -0,0 +1,95 @@
---
Language: Cpp
AccessModifierOffset: -1
AlignAfterOpenBracket: Align
AlignConsecutiveAssignments: false
AlignConsecutiveDeclarations: false
AlignEscapedNewlinesLeft: true
AlignOperands: true
AlignTrailingComments: true
AllowAllParametersOfDeclarationOnNextLine: true
AllowShortBlocksOnASingleLine: true
AllowShortCaseLabelsOnASingleLine: true
AllowShortFunctionsOnASingleLine: All
AllowShortIfStatementsOnASingleLine: true
AllowShortLoopsOnASingleLine: true
AlwaysBreakAfterDefinitionReturnType: None
AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: false
AlwaysBreakTemplateDeclarations: true
BinPackArguments: true
BinPackParameters: true
BraceWrapping:
  AfterClass: true
  AfterControlStatement: false
  AfterEnum: false
  AfterFunction: false
  AfterNamespace: false
  AfterObjCDeclaration: false
  AfterStruct: false
  AfterUnion: false
  BeforeCatch: false
  BeforeElse: false
  IndentBraces: false
BreakBeforeBinaryOperators: NonAssignment
BreakBeforeBraces: Attach
BreakBeforeTernaryOperators: true
BreakConstructorInitializersBeforeComma: false
BreakAfterJavaFieldAnnotations: false
BreakStringLiterals: true
ColumnLimit: 100
CommentPragmas: '^ IWYU pragma:'
BreakBeforeInheritanceComma: false
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
Cpp11BracedListStyle: true
DisableFormat: false
ExperimentalAutoDetectBinPacking: false
FixNamespaceComments: true
ForEachMacros: [ foreach, Q_FOREACH, BOOST_FOREACH ]
IncludeCategories:
  - Regex: '^<.*\.h>'
    Priority: 1
  - Regex: '^<.*'
    Priority: 2
  - Regex: '.*'
    Priority: 3
IncludeIsMainRegex: '([-_](test|unittest))?$'
IndentCaseLabels: true
IndentWidth: 2
IndentWrappedFunctionNames: false
JavaScriptQuotes: Leave
JavaScriptWrapImports: true
KeepEmptyLinesAtTheStartOfBlocks: false
MacroBlockBegin: ''
MacroBlockEnd: ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
ObjCBlockIndentWidth: 2
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: false
PenaltyBreakBeforeFirstCallParameter: 1
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
PenaltyBreakString: 1000
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 200
PointerAlignment: Right
ReflowComments: true
SortIncludes: false
SpaceAfterCStyleCast: false
SpaceAfterTemplateKeyword: false
SpaceBeforeAssignmentOperators: true
SpaceBeforeParens: ControlStatements
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 2
SpacesInAngles: false
SpacesInContainerLiterals: true
SpacesInCStyleCastParentheses: false
SpacesInParentheses: false
SpacesInSquareBrackets: false
Standard: Auto
TabWidth: 8
UseTab: Never
...

.dockerignore (new file, 183 lines)

@@ -0,0 +1,183 @@
# Legacy codes
.legacy/
.data/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
trace_result/
profile_result/
slurm_outs
_data
*.nfs*
output
logs
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
# dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# openai api key
api_key.txt
api_key.json
./*.sh
*.png
*.jpg
*.pdf
.vscode/

View File

@@ -690,7 +690,7 @@ class ClusterSpecConfig:
 @dataclass
 class DatasetConfig:
-    path: str =field(
+    path: str = field(
         default=MISSING,
         metadata={
             "help": "Path to the dataset. Can be a local path or a HuggingFace dataset name."

View File

@@ -47,14 +47,17 @@ class LLMResponse:
     @property
     def output_len(self) -> int:
         return len(self.output_tokens)
 @dataclass
 class VLMRequest(LLMRequest):
-    image_data: Optional[List[ImageObject|str]] = field(default_factory=list)
+    image_data: Optional[List[ImageObject | str]] = field(default_factory=list)
 @dataclass
 class VLMResponse(LLMResponse):
-    input_images: List[ImageObject|str] = field(default_factory=list)
+    input_images: List[ImageObject | str] = field(default_factory=list)
 @dataclass
 class FinetuneSpec:
@@ -142,6 +145,7 @@ class AllocationMode:
             raise ValueError(
                 f"Unknown how to resolve parallelism strategy: {allocation_mode}"
             )
     @staticmethod
     def extract_decoupled_alloc(allocation_mode: str) -> Dict:
         pattern = re.compile(

View File

@@ -4,36 +4,39 @@ import transformers
 VALID_DATASETS = ["gsm8k", "clevr_count_70k"]
 def get_custom_dataset(
     path: str,
     rank: int,
     world_size: int,
-    training_type: str= "sft",
+    training_type: str = "sft",
     split: Optional[str] = None,
     tokenizer: Optional[transformers.PreTrainedTokenizerFast] = None,
     processor: Optional[transformers.AutoProcessor] = None,
 ):
     if "gsm8k" in path and training_type == "sft":
         from examples.arealite.dataset.gsm8k import get_gsm8k_sft_dataset
         return get_gsm8k_sft_dataset(path, split, tokenizer, rank, world_size)
     elif "gsm8k" in path and training_type == "rl":
         from examples.arealite.dataset.gsm8k import get_gsm8k_rl_dataset
         return get_gsm8k_rl_dataset(path, split, rank, world_size)
     elif "clevr_count_70k" in path and training_type == "sft":
         from examples.arealite.dataset.clevr_count_70k import (
             get_clevr_count_70k_sft_dataset,
         )
         return get_clevr_count_70k_sft_dataset(path, split, processor, rank, world_size)
     elif "clevr_count_70k" in path and training_type == "rl":
         from examples.arealite.dataset.clevr_count_70k import (
             get_clevr_count_70k_rl_dataset,
         )
-        return get_clevr_count_70k_rl_dataset(path, split,processor, rank, world_size)
+        return get_clevr_count_70k_rl_dataset(path, split, processor, rank, world_size)
     else:
         raise ValueError(
             f"Dataset {path} with split {split} and training type {training_type} is not supported. "
             f"Supported datasets are: {VALID_DATASETS}. "
         )

View File

@@ -9,8 +9,8 @@ from tensordict import TensorDict
 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
-    AutoProcessor,
     AutoModelForImageTextToText,
+    AutoProcessor,
     PretrainedConfig,
     PreTrainedTokenizerFast,
     get_constant_schedule_with_warmup,
@@ -31,8 +31,8 @@ from arealite.utils.data import (
     unsqueeze_mb_list,
 )
 from arealite.utils.fsdp import get_cosine_schedule_with_warmup
-from arealite.utils.model import disable_dropout_in_model,VALID_VISION_MODELS
-from realhf.api.core.data_api import load_hf_tokenizer, load_hf_processor_and_tokenizer
+from arealite.utils.model import VALID_VISION_MODELS, disable_dropout_in_model
+from realhf.api.core.data_api import load_hf_processor_and_tokenizer, load_hf_tokenizer
 from realhf.base import constants, logging
 logger = logging.getLogger("Base HF Engine")
@@ -55,7 +55,7 @@ class BaseHFEngine(TrainEngine):
         self.own_global_group = False
         self._parallelism_group: dist.ProcessGroup
         self.weight_update_group_initialized = False
         self.model_config = AutoConfig.from_pretrained(
             pretrained_model_name_or_path=self.config.path,
             trust_remote_code=True,
@@ -96,8 +96,10 @@ class BaseHFEngine(TrainEngine):
         dtype = getattr(torch, self.config.dtype)
         if self.is_vision_model:
             dtype = torch.bfloat16
-        self.processor, self.tokenizer = load_hf_processor_and_tokenizer(self.config.path)
+        self.processor, self.tokenizer = load_hf_processor_and_tokenizer(
+            self.config.path
+        )
         tik = time.perf_counter()
         with torch.device("cuda"):
@@ -132,7 +134,7 @@ class BaseHFEngine(TrainEngine):
         )
         if self.config.disable_dropout:
             disable_dropout_in_model(model)
         if self.config.gradient_checkpointing:
             model.gradient_checkpointing_enable(
                 gradient_checkpointing_kwargs={"use_reentrant": False}
@@ -231,14 +233,12 @@ class BaseHFEngine(TrainEngine):
         assert self.lr_scheduler is not None
         self.lr_scheduler.step()
     def prepare_mb_list(self, input_: TensorDict) -> MicroBatchList:
         assert "attention_mask" in input_ and "input_ids" in input_
         if self.is_vision_model:
-            assert "pixel_values" in input_ and "image_grid_thw" in input_, (
-                "For vision-language models, pixel_values and image_grid_thw must be present in input_"
-            )
+            assert (
+                "pixel_values" in input_ and "image_grid_thw" in input_
+            ), "For vision-language models, pixel_values and image_grid_thw must be present in input_"
         if isinstance(input_, dict):
             input_ = TensorDict(input_, batch_size=[input_["input_ids"].shape[0]])
@@ -303,7 +303,6 @@ class BaseHFEngine(TrainEngine):
         loss *= loss_scale
         grad_norm = torch.nn.utils.clip_grad_norm_(
             self.model.parameters(),
             self.optimizer_config.gradient_clipping,

View File

@@ -2,7 +2,7 @@ import os
 import time
 from datetime import datetime
 from typing import Callable, Dict, Optional, Tuple
-from realhf.api.core.data_api import load_hf_processor_and_tokenizer
 import torch
 import torch.distributed as dist
 from tensordict import TensorDict
@@ -11,7 +11,7 @@ from torch.distributed.checkpoint.state_dict import (
     StateDictOptions,
     get_model_state_dict,
 )
-from transformers import PreTrainedTokenizerFast,AutoProcessor
+from transformers import AutoProcessor, PreTrainedTokenizerFast
 from arealite.api.cli_args import TrainEngineConfig
 from arealite.api.engine_api import FinetuneSpec, SaveLoadMeta, WeightUpdateMeta
@@ -26,6 +26,7 @@ from arealite.utils.fsdp import (
     fsdp2_load_full_state_dict,
 )
 from arealite.utils.save_load import get_state_dict_from_repo_id_or_path
+from realhf.api.core.data_api import load_hf_processor_and_tokenizer
 from realhf.base import logging, name_resolve, names, pkg_version
 logger = logging.getLogger("FSDPEngine")
@@ -48,7 +49,6 @@ class FSDPEngine(BaseHFEngine):
         self.create_process_group()
         self.create_device_model()
         # Wrap with FSDP2
         # Simple auto wrap policy
@@ -100,7 +100,10 @@ class FSDPEngine(BaseHFEngine):
             self.load_optimizer_state(meta.path)
     def _save_model_to_hf(
-        self, path: str, tokenizer: Optional[PreTrainedTokenizerFast], processor: Optional[AutoProcessor]
+        self,
+        path: str,
+        tokenizer: Optional[PreTrainedTokenizerFast],
+        processor: Optional[AutoProcessor],
     ):
         """Save model in HuggingFace format."""
         if self.model is None:
@@ -146,7 +149,7 @@ class FSDPEngine(BaseHFEngine):
             dist.barrier()
             torch.cuda.synchronize()
         elif meta.type == "disk":
-            self._save_model_to_hf(meta.path, self.tokenizer,self.processor)
+            self._save_model_to_hf(meta.path, self.tokenizer, self.processor)
             # dist.barrier() are called when _save_model_to_hf finished
             if dist.get_rank() == 0:
                 update_name = names.update_weights_from_disk(
@@ -239,7 +242,6 @@ class FSDPEngine(BaseHFEngine):
         loss *= loss_scale
         loss.backward()
         # NOTE: grad norm clip function is different

View File

@@ -257,8 +257,6 @@ class FSDPPPOActor(FSDPEngine):
         return self.actor.ppo_update(*args, **kwargs)
 def grpo_loss_fn(
     logits: torch.Tensor,
     input_data: Dict,

View File

@@ -5,7 +5,6 @@ from tensordict import TensorDict
 from arealite.api.cli_args import TrainEngineConfig
 from arealite.api.engine_api import TrainEngine
 from arealite.engine.fsdp_engine import FSDPEngine
-# from arealite.engine.vl_fsdp_engine import VL_FSDPEngine
 from arealite.utils.functional import gather_logprobs
 from realhf.base import stats_tracker
@@ -42,6 +41,7 @@ class FSDPLMEngine(FSDPEngine):
     def evaluate_lm(self, data):
         return self.lm_engine.evaluate_lm(data)
 def compute_packed_sft_loss(logits: torch.Tensor, input_: TensorDict) -> torch.Tensor:
     packed_input_ids: torch.Tensor = input_["input_ids"]
     cu_seqlens: torch.Tensor = input_["cu_seqlens"]
@@ -50,7 +50,7 @@ def compute_packed_sft_loss(logits: torch.Tensor, input_: TensorDict) -> torch.T
     logprobs = gather_logprobs(logits, torch.roll(packed_input_ids, shifts=-1, dims=-1))
     loss_mask = torch.roll(loss_mask, shifts=-1, dims=-1)
     logprobs = torch.where(loss_mask, logprobs, 0)
     loss = -logprobs.sum() / loss_mask.count_nonzero()
     with torch.no_grad():
         seqlogp = torch.zeros(

View File

@@ -23,9 +23,9 @@ from arealite.api.io_struct import (
     LLMRequest,
     LLMResponse,
     RolloutStat,
+    VLMRequest,
+    VLMResponse,
     WeightUpdateMeta,
-    VLMRequest,
-    VLMResponse
 )
 from arealite.utils.data import concat_padded_tensors
 from arealite.utils.http import arequest_with_retry, get_default_connector
@@ -228,7 +228,9 @@ class RemoteSGLangEngine(InferenceEngine):
             return server
         raise NotImplementedError("Only round-robin scheduling is implemented.")
-    async def agenerate(self, req: LLMRequest|VLMRequest) -> LLMResponse|VLMResponse:
+    async def agenerate(
+        self, req: LLMRequest | VLMRequest
+    ) -> LLMResponse | VLMResponse:
         """Async version of generate using aiohttp."""
         # Prepare request payload
         gconfig = req.gconfig
@@ -318,28 +320,28 @@ class RemoteSGLangEngine(InferenceEngine):
             sample_params["max_new_tokens"] -= len(output_tokens)
         latency = time.perf_counter() - start_time
         if isinstance(req, VLMRequest):
             response = VLMResponse(
                 input_tokens=req.input_ids,
                 input_images=req.image_data,
                 output_tokens=accumulated_output_tokens,
                 output_logprobs=accumulated_output_logprobs,
                 output_versions=accumulated_versions,
                 stop_reason=stop_reason,
                 latency=latency,
                 ttft=latency,  # Simplified for non-streaming
             )
         else:
-            response=LLMResponse(
+            response = LLMResponse(
                 input_tokens=req.input_ids,
                 output_tokens=accumulated_output_tokens,
                 output_logprobs=accumulated_output_logprobs,
                 output_versions=accumulated_versions,
                 stop_reason=stop_reason,
                 latency=latency,
                 ttft=latency,  # Simplified for non-streaming
             )
         return response
     def update_weights(self, meta):
@@ -526,7 +528,7 @@ class RemoteSGLangEngine(InferenceEngine):
         ):
             try:
                 data = next(self.data_generator)
             except StopIteration:
                 self.data_generator = iter(dataloader)
                 data = next(self.data_generator)

View File

@@ -65,8 +65,9 @@ def pad_sequences_to_tensors(
         return TensorDict()
     skip_keys = {"pixel_values", "image_grid_thw"}
     max_length = max(
-        len(seq) for item in sequence_list
-        for key, seq in item.items()
+        len(seq)
+        for item in sequence_list
+        for key, seq in item.items()
         if key not in skip_keys
     )
     result = {}
@@ -79,14 +80,18 @@ def pad_sequences_to_tensors(
             x = item[key]
             if not torch.is_tensor(x):
                 x = torch.tensor(x)
-            padded_x=torch.nn.functional.pad(
+            padded_x = torch.nn.functional.pad(
                 x, (0, max_length - len(item[key])), value=pad_value
             )
             padded.append(padded_x)
         result[key] = torch.stack(padded)
     attention_mask = [
         [1] * len(next(iter(item[key] for key in item.keys() if key not in skip_keys)))
-        + [0] * (max_length - len(next(iter(item[key] for key in item.keys() if key not in skip_keys))))
+        + [0]
+        * (
+            max_length
+            - len(next(iter(item[key] for key in item.keys() if key not in skip_keys)))
+        )
         for item in sequence_list
     ]
     result["attention_mask"] = torch.tensor(attention_mask, dtype=torch.bool)
@@ -139,7 +144,7 @@ def concat_padded_tensors(
                 tensors_to_concat.append(tensor)
                 continue
             current_length = tensor.shape[1]
-            if key == "pixel_values" or key== "image_grid_thw":
+            if key == "pixel_values" or key == "image_grid_thw":
                 tensors_to_concat.append(tensor)
                 continue
             if current_length < max_length:
@@ -150,7 +155,7 @@ def concat_padded_tensors(
                     padding = torch.zeros(
                         (tensor.shape[0], pad_width), dtype=tensor.dtype
                     )
                 else:
                     # Pad feature tensors with pad_value
                     padding = torch.full(
@@ -323,7 +328,7 @@ def split_padded_tensor_dict_into_mb_list(
     to_split = {}
     not_to_split = {}
     for key, value in data.items():
-        if key=="image_grid_thw" or key=="pixel_values":
+        if key == "image_grid_thw" or key == "pixel_values":
             continue
         if not torch.is_tensor(value) or value.numel() != bs * max_seqlen:
             not_to_split[key] = value
@@ -368,7 +373,7 @@ def split_padded_tensor_dict_into_mb_list(
     for group_index in group_indices:
         group_pixel_values = [pixel_values[i] for i in group_index]
-        group_image_grid_thw = [image_grid_thw[i].squeeze()for i in group_index]
+        group_image_grid_thw = [image_grid_thw[i].squeeze() for i in group_index]
         # Stack pixel_values for each group (assuming pixel_values is a list of tensors)
         pixel_values_split.append(torch.stack(group_pixel_values))

View File

@@ -46,7 +46,7 @@ def fsdp2_clip_grad_norm_(
     grads = [p.grad for p in parameters if p.grad is not None]
     total_norm = _get_total_norm(grads, norm_type, error_if_nonfinite, foreach)
     total_norm = total_norm.to(torch.cuda.current_device(), non_blocking=True)
     _clip_grads_with_norm_(parameters, max_norm, total_norm, foreach)
     return total_norm

View File

@@ -7,20 +7,17 @@ from typing import List
 from PIL.Image import Image as ImageObject
-def image2base64(images: List[ImageObject]|ImageObject)-> List[str]|str:
+def image2base64(images: List[ImageObject] | ImageObject) -> List[str] | str:
     if isinstance(images, ImageObject):
         images = [images]
     byte_images = []
     for image in images:
         with BytesIO() as buffer:
             image.save(buffer, format="PNG")
             buffer.seek(0)
-            byte_image = base64.b64encode(buffer.read()).decode('utf-8')
+            byte_image = base64.b64encode(buffer.read()).decode("utf-8")
             byte_images.append(byte_image)
     return byte_images

View File

@@ -5,6 +5,7 @@ VALID_VISION_MODELS = [
     "qwen2_5_vl",
 ]
 # Copied from trl
 def disable_dropout_in_model(model: torch.nn.Module) -> None:
     for module in model.modules():

View File

@@ -34,7 +34,7 @@ class VL_RLVRWorkflow(RLVRWorkflow):
             return_tensors="pt",
         )
-        input_ids=processed_input["input_ids"].tolist()[0]
+        input_ids = processed_input["input_ids"].tolist()[0]
         n_samples = self.gconfig.n_samples
@@ -62,13 +62,19 @@ class VL_RLVRWorkflow(RLVRWorkflow):
                 completion_ids=resp.output_tokens,
                 **data,
             )
             res = dict(
                 # unsqueeze to add an additional batch dimension
                 input_ids=torch.tensor(seq).unsqueeze(0),
                 loss_mask=torch.tensor(loss_mask).unsqueeze(0),
-                pixel_values=processed_input["pixel_values"].clone().detach().unsqueeze(0),
-                image_grid_thw=processed_input["image_grid_thw"].clone().detach().unsqueeze(0),
+                pixel_values=processed_input["pixel_values"]
+                .clone()
+                .detach()
+                .unsqueeze(0),
+                image_grid_thw=processed_input["image_grid_thw"]
+                .clone()
+                .detach()
+                .unsqueeze(0),
                 logprobs=torch.tensor(logprobs).unsqueeze(0),
                 versions=torch.tensor(versions).unsqueeze(0),
                 attention_mask=torch.ones(len(seq), dtype=torch.bool).unsqueeze(0),

View File

@@ -1,13 +1,15 @@
 import os
 import re
 import sys
-import wandb
 import torch
 import torch.distributed as dist
+import wandb
 from torchdata.stateful_dataloader import StatefulDataLoader
 from arealite.api.cli_args import GRPOConfig, load_expr_config
 from arealite.api.io_struct import FinetuneSpec, WeightUpdateMeta
+from arealite.dataset.__init__ import get_custom_dataset
 from arealite.engine.ppo.actor import FSDPPPOActor
 from arealite.engine.sglang_remote import RemoteSGLangEngine
 from arealite.utils.device import log_gpu_stats
@@ -17,13 +19,12 @@ from arealite.utils.stats_logger import StatsLogger
 from arealite.workflow.vl_rlvr import VL_RLVRWorkflow
 from realhf.api.core.data_api import load_hf_processor_and_tokenizer
 from realhf.base import stats_tracker
-from arealite.dataset.__init__ import get_custom_dataset
 def extract_answer(pred_str, data_name, use_last_number=True):
     match = re.findall(r"\[([0-9\.]+)\]", pred_str)
     if match:
         return match[-1]
     return ""
@@ -55,32 +56,35 @@ def extract_solution(solution_str, method="strict") -> str | None:
             break
     return final_answer
-def clevr_count_70k_reward_fn(prompt, completions, prompt_ids, completion_ids, answer, **kwargs):
+def clevr_count_70k_reward_fn(
+    prompt, completions, prompt_ids, completion_ids, answer, **kwargs
+):
     is_thinking = "thinking" in completions.lower()
     sol = extract_answer(completions, data_name="")  # str number
-    ans =answer
+    ans = answer
     if sol is None:
         return 0
     if ans is None:
         return 0
     if sol.strip() == ans.strip():
         print(f"completions: {completions}, answer: {answer}")
         if is_thinking:
             return 1
         else:
             return 1
     if re.match(r"^\[\d+(\.\d+)?\]$", sol.strip()):
         return 0.05
     return 0
 def main(args):
-    os.environ["WANDB_API_KEY"]=""
+    os.environ["WANDB_API_KEY"] = ""
     wandb.init(project="clevr_70k")
     config, _ = load_expr_config(args, GRPOConfig)
@@ -89,22 +93,22 @@ def main(args):
     rank = int(os.getenv("RANK"))
     world_size = int(os.getenv("WORLD_SIZE"))
     processor, tokenizer = load_hf_processor_and_tokenizer(config.tokenizer_path)
-    train_dataset=get_custom_dataset(
+    train_dataset = get_custom_dataset(
         path=config.train_dataset.path,
         rank=rank,
         world_size=world_size,
         split="train",
         training_type="rl",
-        processor=processor
+        processor=processor,
     )
-    valid_dataset=get_custom_dataset(
+    valid_dataset = get_custom_dataset(
         path=config.valid_dataset.path,
         rank=rank,
         world_size=world_size,
         split="test",
         training_type="rl",
-        processor=processor
+        processor=processor,
     )
     # Create dataset and dataloaders
     train_dataloader = StatefulDataLoader(
         train_dataset,
@@ -141,7 +145,7 @@ def main(args):
     actor.initialize(None, ft_spec)
     ref = None
     if config.actor.kl_ctl > 0 and config.ref is not None:
-        ref =FSDPPPOActor(config=config.ref)
+        ref = FSDPPPOActor(config=config.ref)
         ref.initialize(None, ft_spec)
     # Create rollout workflow
@@ -199,7 +203,7 @@ def main(args):
         if ref is not None:
             with stats_tracker.record_timing("ref_logp"):
                 batch["ref_logp"] = ref.compute_logp(batch)
                 log_gpu_stats("ref logp")
         with stats_tracker.record_timing("compute_advantage"):
@@ -211,8 +215,8 @@ def main(args):
             stats_tracker.scope("grpo_actor"),
         ):
             stats = actor.ppo_update(batch)
-            wandb.log({"actor_reward": stats[0]['grpo_actor/final_reward/avg']})
+            wandb.log({"actor_reward": stats[0]["grpo_actor/final_reward/avg"]})
             actor.step_lr_scheduler()
             log_gpu_stats("ppo update")
@@ -251,7 +255,14 @@ def main(args):
                 cnt += 1
             batch = eval_rollout.wait(cnt, timeout=None)
             rewards = batch["rewards"].float().to(actor.device)
-            wandb.log({"eval_reward": rewards.mean().item(), "epoch": epoch, "step": step, "global_step": global_step})
+            wandb.log(
+                {
+                    "eval_reward": rewards.mean().item(),
+                    "epoch": epoch,
+                    "step": step,
+                    "global_step": global_step,
+                }
+            )
             with stats_tracker.scope("grpo-eval"):
                 stats_tracker.denominator(
                     n_seqs=torch.ones(
@@ -280,6 +291,6 @@ def main(args):
     actor.destroy()
     wandb.finish()
 if __name__ == "__main__":
     main(sys.argv[1:])

View File

@@ -1,9 +1,11 @@
 import os
 import sys
 from torchdata.stateful_dataloader import StatefulDataLoader
 from arealite.api.cli_args import SFTConfig, load_expr_config
 from arealite.api.io_struct import FinetuneSpec
+from arealite.dataset.__init__ import get_custom_dataset
 from arealite.engine.sft.lm_engine import FSDPLMEngine
 from arealite.utils.data import pad_sequences_to_tensors
 from arealite.utils.evaluator import Evaluator
@@ -11,7 +13,6 @@ from arealite.utils.saver import Saver
 from arealite.utils.stats_logger import StatsLogger
 from realhf.api.core.data_api import load_hf_processor_and_tokenizer
 from realhf.base import stats_tracker
-from arealite.dataset.__init__ import get_custom_dataset
 def main_sft():
@@ -21,25 +22,25 @@ def main_sft():
     rank = int(os.getenv("RANK"))
     world_size = int(os.getenv("WORLD_SIZE"))
     processor, tokenizer = load_hf_processor_and_tokenizer(config.tokenizer_path)
-    train_dataset=get_custom_dataset(
+    train_dataset = get_custom_dataset(
         path=config.train_dataset.path,
         rank=rank,
         world_size=world_size,
         split="train",
         training_type="sft",
         tokenizer=tokenizer,
         processor=processor,
     )
-    valid_dataset=get_custom_dataset(
+    valid_dataset = get_custom_dataset(
         path=config.valid_dataset.path,
         rank=rank,
         world_size=world_size,
         split="test",
         training_type="sft",
         tokenizer=tokenizer,
         processor=processor,
     )
     # Create dataset and dataloaders
     train_dataloader = StatefulDataLoader(
         train_dataset,

View File

@@ -1,26 +1,39 @@
-from typing import Any, Dict, Optional, Union
+import base64
 import math
-from PIL.Image import Image as ImageObject
 import os
+from io import BytesIO
+from typing import Any, Dict, Optional, Union
 from datasets import load_dataset
 from datasets.distributed import split_dataset_by_node
-import base64
-from io import BytesIO
+from PIL.Image import Image as ImageObject
-def input_text(text:str):
+def input_text(text: str):
     return {"type": "input_text", "text": text}
 def input_image(base64_image: str):
-    return {"type": "input_image", "image_url": f"data:image/jpeg;base64,{base64_image}"}
-def build_raw_message(sample: Dict[str, Any], base64_images: list[str]) -> list[Dict[str, Any]]:
+    return {
+        "type": "input_image",
+        "image_url": f"data:image/jpeg;base64,{base64_image}",
+    }
+def build_raw_message(
+    sample: Dict[str, Any], base64_images: list[str]
+) -> list[Dict[str, Any]]:
     raw_message = []
-    problem_parts = [part.strip() for part in sample["problem"].split("<image>") if part.strip()]
+    problem_parts = [
+        part.strip() for part in sample["problem"].split("<image>") if part.strip()
+    ]
     insert_list = []
     for i, part in enumerate(problem_parts):
         if i > 0 or sample["problem"].startswith("<image>"):
             insert_list.append("image")
         part = part.strip()
         if part:
             insert_list.append("text")
     image_index = 0
     text_index = 0
@@ -38,17 +51,25 @@ def build_raw_message(sample: Dict[str, Any], base64_images: list[str]) -> list[
 def encode_image(image_file):
     return base64.b64encode(image_file).decode("utf-8")
 def convert_image(
-    image: Union[Dict[str, Any], ImageObject, str], min_pixels: Optional[int], max_pixels: Optional[int]
+    image: Union[Dict[str, Any], ImageObject, str],
+    min_pixels: Optional[int],
+    max_pixels: Optional[int],
 ) -> ImageObject:
     if max_pixels is not None and (image.width * image.height) > max_pixels:
         resize_factor = math.sqrt(max_pixels / (image.width * image.height))
-        width, height = int(image.width * resize_factor), int(image.height * resize_factor)
+        width, height = int(image.width * resize_factor), int(
+            image.height * resize_factor
+        )
         image = image.resize((width, height))
     if min_pixels is not None and (image.width * image.height) < min_pixels:
         resize_factor = math.sqrt(min_pixels / (image.width * image.height))
-        width, height = int(image.width * resize_factor), int(image.height * resize_factor)
+        width, height = int(image.width * resize_factor), int(
+            image.height * resize_factor
+        )
         image = image.resize((width, height))
     if image.mode != "RGB":
@@ -57,32 +78,34 @@ def convert_image(
         image.save(output, format="JPEG")
     return output.getvalue()
 def get_clevr_count_70k_sft_dataset(path, split, processor, rank, world_size):
-    '''
+    """
     "clevr_count_70k": {
         "image_key": "images",
         "question_key": "problem",
         "answer_key": "answer"
     },
-    '''
+    """
     dataset = load_dataset(path=path, split=split)
     dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
     tokenizer = processor.tokenizer
     def process_example(example, idx):
         # Add query_id column
         images = example["images"]
-        if 'qwen' in processor.image_processor.image_processor_type.lower():
-            image_token="<|vision_start|><|image_pad|><|vision_end|>"
+        if "qwen" in processor.image_processor.image_processor_type.lower():
+            image_token = "<|vision_start|><|image_pad|><|vision_end|>"
         else:
             image_token = processor.image_token if processor is not None else "<image>"
         example["problem"] = example["problem"].replace("<image>", image_token)
         processed_images = []
         for image in images:
-            processed_images.append(convert_image(image,113*113,336*336))
+            processed_images.append(convert_image(image, 113 * 113, 336 * 336))
         example["images"] = processed_images
         example["seq"] = example["problem"] + example["answer"] + tokenizer.eos_token
         return example
     dataset = dataset.map(
@@ -91,8 +114,8 @@ def get_clevr_count_70k_sft_dataset(path, split, processor, rank, world_size):
     )
     def _process(example):
-        text=example["seq"]
-        processed_input=processor(
+        text = example["seq"]
+        processed_input = processor(
             text=[text],
             images=example["images"],
             padding=False,
@@ -101,38 +124,52 @@ def get_clevr_count_70k_sft_dataset(path, split, processor, rank, world_size):
             return_attention_mask=False,
         )
-        example["input_ids"] =processed_input["input_ids"].squeeze(0)
+        example["input_ids"] = processed_input["input_ids"].squeeze(0)
         example["pixel_values"] = processed_input["pixel_values"]
         example["image_grid_thw"] = processed_input["image_grid_thw"]
         answer_token = tokenizer.encode(example["answer"])
-        loss_mask = [0] * (len(example["input_ids"]) - len(answer_token))+[1]*len(answer_token)
-        example["loss_mask"]=loss_mask
+        loss_mask = [0] * (len(example["input_ids"]) - len(answer_token)) + [1] * len(
+            answer_token
+        )
+        example["loss_mask"] = loss_mask
         return example
-    dataset = dataset.map(lambda x: _process(x),remove_columns=["images","seq","problem","answer"])
+    dataset = dataset.map(
+        lambda x: _process(x), remove_columns=["images", "seq", "problem", "answer"]
+    )
     return dataset
-def get_clevr_count_70k_rl_dataset(path, split,processor, rank, world_size):
+def get_clevr_count_70k_rl_dataset(path, split, processor, rank, world_size):
     dataset = load_dataset(path=path, split=split)
     dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
     def process(sample):
-        processed_images = [convert_image(image, 113*113, 336*336) for image in sample["images"]]
-        if 'qwen' in processor.image_processor.image_processor_type.lower():
-            image_token="<|vision_start|><|image_pad|><|vision_end|>"
+        processed_images = [
+            convert_image(image, 113 * 113, 336 * 336) for image in sample["images"]
+        ]
+        if "qwen" in processor.image_processor.image_processor_type.lower():
+            image_token = "<|vision_start|><|image_pad|><|vision_end|>"
         else:
             image_token = processor.image_token if processor is not None else "<image>"
         system_prompt = {
             "role": "system",
             "content": (
                 "Solve the following question: count the number of items in the image and provide the final answer in [ ] format, ensuring that only the number is inside the brackets without any additional text or explanations. "
-            )
+            ),
         }
-        messages =[{"role": "user", "content": sample["problem"].replace("<image>", image_token)}]
+        messages = [
+            {
+                "role": "user",
+                "content": sample["problem"].replace("<image>", image_token),
+            }
+        ]
         messages.insert(0, system_prompt)
-        messages=processor.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+        messages = processor.tokenizer.apply_chat_template(
+            messages, add_generation_prompt=True, tokenize=False
+        )
         return {"messages": messages, "images": processed_images}
     dataset = dataset.map(process).remove_columns(["problem"])
     return dataset

View File

@@ -1,9 +1,11 @@
-from datasets import load_dataset, Dataset
+from datasets import Dataset, load_dataset
 from datasets.distributed import split_dataset_by_node
 def get_gsm8k_sft_dataset(path, split, tokenizer, rank, world_size):
     dataset = load_dataset(path=path, name="main", split=split)
     dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
     def process(sample):
         seq_token = tokenizer.encode(
             sample["question"] + sample["answer"] + tokenizer.eos_token
@@ -15,12 +17,14 @@ def get_gsm8k_sft_dataset(path, split, tokenizer, rank, world_size):
     dataset = dataset.map(process).remove_columns(["question", "answer"])
     return dataset
-def get_gsm8k_rl_dataset(path,split, rank, world_size):
+def get_gsm8k_rl_dataset(path, split, rank, world_size):
     dataset = load_dataset(path=path, name="main", split=split)
     dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
     def process(sample):
         messages = [{"role": "user", "content": sample["question"]}]
         return {"messages": messages}
     dataset = dataset.map(process).remove_columns(["question"])
     return dataset

View File

@@ -68,6 +68,7 @@ def load_hf_tokenizer(
         tokenizer.pad_token_id = tokenizer.eos_token_id
     return tokenizer
 @lru_cache(maxsize=8)
 def load_hf_processor_and_tokenizer(
     model_name_or_path: str,