mirror of https://github.com/inclusionAI/AReaL
Merge branch 'lcy/refactor' of https://code.alipay.com/inclusionAI/AReaL into lcy/refactor
This commit is contained in: commit 391bd85e44

@@ -0,0 +1,95 @@
+---
+Language: Cpp
+AccessModifierOffset: -1
+AlignAfterOpenBracket: Align
+AlignConsecutiveAssignments: false
+AlignConsecutiveDeclarations: false
+AlignEscapedNewlinesLeft: true
+AlignOperands: true
+AlignTrailingComments: true
+AllowAllParametersOfDeclarationOnNextLine: true
+AllowShortBlocksOnASingleLine: true
+AllowShortCaseLabelsOnASingleLine: true
+AllowShortFunctionsOnASingleLine: All
+AllowShortIfStatementsOnASingleLine: true
+AllowShortLoopsOnASingleLine: true
+AlwaysBreakAfterDefinitionReturnType: None
+AlwaysBreakAfterReturnType: None
+AlwaysBreakBeforeMultilineStrings: false
+AlwaysBreakTemplateDeclarations: true
+BinPackArguments: true
+BinPackParameters: true
+BraceWrapping:
+  AfterClass: true
+  AfterControlStatement: false
+  AfterEnum: false
+  AfterFunction: false
+  AfterNamespace: false
+  AfterObjCDeclaration: false
+  AfterStruct: false
+  AfterUnion: false
+  BeforeCatch: false
+  BeforeElse: false
+  IndentBraces: false
+BreakBeforeBinaryOperators: NonAssignment
+BreakBeforeBraces: Attach
+BreakBeforeTernaryOperators: true
+BreakConstructorInitializersBeforeComma: false
+BreakAfterJavaFieldAnnotations: false
+BreakStringLiterals: true
+ColumnLimit: 100
+CommentPragmas: '^ IWYU pragma:'
+BreakBeforeInheritanceComma: false
+ConstructorInitializerAllOnOneLineOrOnePerLine: true
+ConstructorInitializerIndentWidth: 4
+ContinuationIndentWidth: 4
+Cpp11BracedListStyle: true
+DisableFormat: false
+ExperimentalAutoDetectBinPacking: false
+FixNamespaceComments: true
+ForEachMacros: [ foreach, Q_FOREACH, BOOST_FOREACH ]
+IncludeCategories:
+  - Regex: '^<.*\.h>'
+    Priority: 1
+  - Regex: '^<.*'
+    Priority: 2
+  - Regex: '.*'
+    Priority: 3
+IncludeIsMainRegex: '([-_](test|unittest))?$'
+IndentCaseLabels: true
+IndentWidth: 2
+IndentWrappedFunctionNames: false
+JavaScriptQuotes: Leave
+JavaScriptWrapImports: true
+KeepEmptyLinesAtTheStartOfBlocks: false
+MacroBlockBegin: ''
+MacroBlockEnd: ''
+MaxEmptyLinesToKeep: 1
+NamespaceIndentation: None
+ObjCBlockIndentWidth: 2
+ObjCSpaceAfterProperty: false
+ObjCSpaceBeforeProtocolList: false
+PenaltyBreakBeforeFirstCallParameter: 1
+PenaltyBreakComment: 300
+PenaltyBreakFirstLessLess: 120
+PenaltyBreakString: 1000
+PenaltyExcessCharacter: 1000000
+PenaltyReturnTypeOnItsOwnLine: 200
+PointerAlignment: Right
+ReflowComments: true
+SortIncludes: false
+SpaceAfterCStyleCast: false
+SpaceAfterTemplateKeyword: false
+SpaceBeforeAssignmentOperators: true
+SpaceBeforeParens: ControlStatements
+SpaceInEmptyParentheses: false
+SpacesBeforeTrailingComments: 2
+SpacesInAngles: false
+SpacesInContainerLiterals: true
+SpacesInCStyleCastParentheses: false
+SpacesInParentheses: false
+SpacesInSquareBrackets: false
+Standard: Auto
+TabWidth: 8
+UseTab: Never
+...

@@ -0,0 +1,183 @@
+# Legacy codes
+.legacy/
+.data/
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+trace_result/
+profile_result/
+
+slurm_outs
+_data
+*.nfs*
+output
+logs
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+# dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+# openai api key
+api_key.txt
+api_key.json
+
+./*.sh
+*.png
+*.jpg
+*.pdf
+
+.vscode/

@@ -690,7 +690,7 @@ class ClusterSpecConfig:

 @dataclass
 class DatasetConfig:
-    path: str =field(
+    path: str = field(
         default=MISSING,
         metadata={
             "help": "Path to the dataset. Can be a local path or a HuggingFace dataset name."

@@ -47,14 +47,17 @@ class LLMResponse:
     @property
     def output_len(self) -> int:
         return len(self.output_tokens)

+
 @dataclass
 class VLMRequest(LLMRequest):
-    image_data: Optional[List[ImageObject|str]] = field(default_factory=list)
+    image_data: Optional[List[ImageObject | str]] = field(default_factory=list)

+
 @dataclass
 class VLMResponse(LLMResponse):
-    input_images: List[ImageObject|str] = field(default_factory=list)
+    input_images: List[ImageObject | str] = field(default_factory=list)

+
 @dataclass
 class FinetuneSpec:

@@ -142,6 +145,7 @@ class AllocationMode:
         raise ValueError(
             f"Unknown how to resolve parallelism strategy: {allocation_mode}"
         )
+
     @staticmethod
     def extract_decoupled_alloc(allocation_mode: str) -> Dict:
         pattern = re.compile(

@@ -4,36 +4,39 @@ import transformers

 VALID_DATASETS = ["gsm8k", "clevr_count_70k"]


 def get_custom_dataset(
     path: str,
     rank: int,
     world_size: int,
-    training_type: str= "sft",
+    training_type: str = "sft",
     split: Optional[str] = None,
     tokenizer: Optional[transformers.PreTrainedTokenizerFast] = None,
     processor: Optional[transformers.AutoProcessor] = None,
 ):
     if "gsm8k" in path and training_type == "sft":
         from examples.arealite.dataset.gsm8k import get_gsm8k_sft_dataset

         return get_gsm8k_sft_dataset(path, split, tokenizer, rank, world_size)
     elif "gsm8k" in path and training_type == "rl":
         from examples.arealite.dataset.gsm8k import get_gsm8k_rl_dataset

         return get_gsm8k_rl_dataset(path, split, rank, world_size)
     elif "clevr_count_70k" in path and training_type == "sft":
         from examples.arealite.dataset.clevr_count_70k import (
             get_clevr_count_70k_sft_dataset,
         )

         return get_clevr_count_70k_sft_dataset(path, split, processor, rank, world_size)
     elif "clevr_count_70k" in path and training_type == "rl":
         from examples.arealite.dataset.clevr_count_70k import (
             get_clevr_count_70k_rl_dataset,
         )
-        return get_clevr_count_70k_rl_dataset(path, split,processor, rank, world_size)
+
+        return get_clevr_count_70k_rl_dataset(path, split, processor, rank, world_size)
     else:
         raise ValueError(
             f"Dataset {path} with split {split} and training type {training_type} is not supported. "
             f"Supported datasets are: {VALID_DATASETS}. "
         )

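For reference, the dispatcher above routes purely on substrings of path plus training_type. A minimal usage sketch in Python (hypothetical values; it assumes the arealite package and its example datasets are importable and uses the HuggingFace ids only as illustrations, not as code from this commit):

    from transformers import AutoTokenizer

    from arealite.dataset.__init__ import get_custom_dataset

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")  # assumed model id
    train_dataset = get_custom_dataset(
        path="openai/gsm8k",   # the "gsm8k" substring selects the gsm8k loaders
        rank=0,
        world_size=1,
        training_type="sft",   # "sft" or "rl"
        split="train",
        tokenizer=tokenizer,
    )
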
@@ -9,8 +9,8 @@ from tensordict import TensorDict
 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
-    AutoProcessor,
     AutoModelForImageTextToText,
+    AutoProcessor,
     PretrainedConfig,
     PreTrainedTokenizerFast,
     get_constant_schedule_with_warmup,

@@ -31,8 +31,8 @@ from arealite.utils.data import (
     unsqueeze_mb_list,
 )
 from arealite.utils.fsdp import get_cosine_schedule_with_warmup
-from arealite.utils.model import disable_dropout_in_model,VALID_VISION_MODELS
-from realhf.api.core.data_api import load_hf_tokenizer, load_hf_processor_and_tokenizer
+from arealite.utils.model import VALID_VISION_MODELS, disable_dropout_in_model
+from realhf.api.core.data_api import load_hf_processor_and_tokenizer, load_hf_tokenizer
 from realhf.base import constants, logging

 logger = logging.getLogger("Base HF Engine")

@@ -55,7 +55,7 @@ class BaseHFEngine(TrainEngine):
         self.own_global_group = False
         self._parallelism_group: dist.ProcessGroup
         self.weight_update_group_initialized = False

         self.model_config = AutoConfig.from_pretrained(
             pretrained_model_name_or_path=self.config.path,
             trust_remote_code=True,

@@ -96,8 +96,10 @@ class BaseHFEngine(TrainEngine):
         dtype = getattr(torch, self.config.dtype)

         if self.is_vision_model:
-            dtype = torch.bfloat16
-            self.processor, self.tokenizer = load_hf_processor_and_tokenizer(self.config.path)
+            dtype = torch.bfloat16
+            self.processor, self.tokenizer = load_hf_processor_and_tokenizer(
+                self.config.path
+            )

         tik = time.perf_counter()
         with torch.device("cuda"):

@@ -132,7 +134,7 @@ class BaseHFEngine(TrainEngine):
         )
         if self.config.disable_dropout:
             disable_dropout_in_model(model)

         if self.config.gradient_checkpointing:
             model.gradient_checkpointing_enable(
                 gradient_checkpointing_kwargs={"use_reentrant": False}

@@ -231,14 +233,12 @@ class BaseHFEngine(TrainEngine):
         assert self.lr_scheduler is not None
         self.lr_scheduler.step()

-
     def prepare_mb_list(self, input_: TensorDict) -> MicroBatchList:
         assert "attention_mask" in input_ and "input_ids" in input_
         if self.is_vision_model:
-            assert "pixel_values" in input_ and "image_grid_thw" in input_, (
-                "For vision-language models, pixel_values and image_grid_thw must be present in input_"
-            )
+            assert (
+                "pixel_values" in input_ and "image_grid_thw" in input_
+            ), "For vision-language models, pixel_values and image_grid_thw must be present in input_"

         if isinstance(input_, dict):
             input_ = TensorDict(input_, batch_size=[input_["input_ids"].shape[0]])

@@ -303,7 +303,6 @@ class BaseHFEngine(TrainEngine):

         loss *= loss_scale

-
         grad_norm = torch.nn.utils.clip_grad_norm_(
             self.model.parameters(),
             self.optimizer_config.gradient_clipping,

@@ -2,7 +2,7 @@ import os
 import time
 from datetime import datetime
 from typing import Callable, Dict, Optional, Tuple
-from realhf.api.core.data_api import load_hf_processor_and_tokenizer

 import torch
 import torch.distributed as dist
 from tensordict import TensorDict

@@ -11,7 +11,7 @@ from torch.distributed.checkpoint.state_dict import (
     StateDictOptions,
     get_model_state_dict,
 )
-from transformers import PreTrainedTokenizerFast,AutoProcessor
+from transformers import AutoProcessor, PreTrainedTokenizerFast

 from arealite.api.cli_args import TrainEngineConfig
 from arealite.api.engine_api import FinetuneSpec, SaveLoadMeta, WeightUpdateMeta

@@ -26,6 +26,7 @@ from arealite.utils.fsdp import (
     fsdp2_load_full_state_dict,
 )
 from arealite.utils.save_load import get_state_dict_from_repo_id_or_path
+from realhf.api.core.data_api import load_hf_processor_and_tokenizer
 from realhf.base import logging, name_resolve, names, pkg_version

 logger = logging.getLogger("FSDPEngine")

@@ -48,7 +49,6 @@ class FSDPEngine(BaseHFEngine):

         self.create_process_group()
         self.create_device_model()
-
         # Wrap with FSDP2
         # Simple auto wrap policy

@@ -100,7 +100,10 @@ class FSDPEngine(BaseHFEngine):
             self.load_optimizer_state(meta.path)

     def _save_model_to_hf(
-        self, path: str, tokenizer: Optional[PreTrainedTokenizerFast], processor: Optional[AutoProcessor]
+        self,
+        path: str,
+        tokenizer: Optional[PreTrainedTokenizerFast],
+        processor: Optional[AutoProcessor],
     ):
         """Save model in HuggingFace format."""
         if self.model is None:

@@ -146,7 +149,7 @@ class FSDPEngine(BaseHFEngine):
             dist.barrier()
             torch.cuda.synchronize()
         elif meta.type == "disk":
-            self._save_model_to_hf(meta.path, self.tokenizer,self.processor)
+            self._save_model_to_hf(meta.path, self.tokenizer, self.processor)
             # dist.barrier() are called when _save_model_to_hf finished
             if dist.get_rank() == 0:
                 update_name = names.update_weights_from_disk(

@@ -239,7 +242,6 @@ class FSDPEngine(BaseHFEngine):

         loss *= loss_scale
         loss.backward()
-
         # NOTE: grad norm clip function is different

@@ -257,8 +257,6 @@ class FSDPPPOActor(FSDPEngine):
         return self.actor.ppo_update(*args, **kwargs)


-
-
 def grpo_loss_fn(
     logits: torch.Tensor,
     input_data: Dict,

@@ -5,7 +5,6 @@ from tensordict import TensorDict
 from arealite.api.cli_args import TrainEngineConfig
 from arealite.api.engine_api import TrainEngine
 from arealite.engine.fsdp_engine import FSDPEngine
-# from arealite.engine.vl_fsdp_engine import VL_FSDPEngine
 from arealite.utils.functional import gather_logprobs
 from realhf.base import stats_tracker

@@ -42,6 +41,7 @@ class FSDPLMEngine(FSDPEngine):
     def evaluate_lm(self, data):
         return self.lm_engine.evaluate_lm(data)

+
 def compute_packed_sft_loss(logits: torch.Tensor, input_: TensorDict) -> torch.Tensor:
     packed_input_ids: torch.Tensor = input_["input_ids"]
     cu_seqlens: torch.Tensor = input_["cu_seqlens"]

@@ -50,7 +50,7 @@ def compute_packed_sft_loss(logits: torch.Tensor, input_: TensorDict) -> torch.Tensor:
     logprobs = gather_logprobs(logits, torch.roll(packed_input_ids, shifts=-1, dims=-1))
     loss_mask = torch.roll(loss_mask, shifts=-1, dims=-1)
     logprobs = torch.where(loss_mask, logprobs, 0)

     loss = -logprobs.sum() / loss_mask.count_nonzero()
     with torch.no_grad():
         seqlogp = torch.zeros(

@@ -23,9 +23,9 @@ from arealite.api.io_struct import (
     LLMRequest,
     LLMResponse,
     RolloutStat,
+    VLMRequest,
+    VLMResponse,
     WeightUpdateMeta,
-    VLMRequest,
-    VLMResponse
 )
 from arealite.utils.data import concat_padded_tensors
 from arealite.utils.http import arequest_with_retry, get_default_connector

|
@ -228,7 +228,9 @@ class RemoteSGLangEngine(InferenceEngine):
|
|||
return server
|
||||
raise NotImplementedError("Only round-robin scheduling is implemented.")
|
||||
|
||||
async def agenerate(self, req: LLMRequest|VLMRequest) -> LLMResponse|VLMResponse:
|
||||
async def agenerate(
|
||||
self, req: LLMRequest | VLMRequest
|
||||
) -> LLMResponse | VLMResponse:
|
||||
"""Async version of generate using aiohttp."""
|
||||
# Prepare request payload
|
||||
gconfig = req.gconfig
|
||||
|
@@ -318,28 +320,28 @@ class RemoteSGLangEngine(InferenceEngine):
             sample_params["max_new_tokens"] -= len(output_tokens)

         latency = time.perf_counter() - start_time

         if isinstance(req, VLMRequest):
             response = VLMResponse(
-            input_tokens=req.input_ids,
-            input_images=req.image_data,
-            output_tokens=accumulated_output_tokens,
-            output_logprobs=accumulated_output_logprobs,
-            output_versions=accumulated_versions,
-            stop_reason=stop_reason,
-            latency=latency,
-            ttft=latency,  # Simplified for non-streaming
+                input_tokens=req.input_ids,
+                input_images=req.image_data,
+                output_tokens=accumulated_output_tokens,
+                output_logprobs=accumulated_output_logprobs,
+                output_versions=accumulated_versions,
+                stop_reason=stop_reason,
+                latency=latency,
+                ttft=latency,  # Simplified for non-streaming
             )
         else:
-            response=LLMResponse(
-            input_tokens=req.input_ids,
-            output_tokens=accumulated_output_tokens,
-            output_logprobs=accumulated_output_logprobs,
-            output_versions=accumulated_versions,
-            stop_reason=stop_reason,
-            latency=latency,
-            ttft=latency,  # Simplified for non-streaming
-            )
+            response = LLMResponse(
+                input_tokens=req.input_ids,
+                output_tokens=accumulated_output_tokens,
+                output_logprobs=accumulated_output_logprobs,
+                output_versions=accumulated_versions,
+                stop_reason=stop_reason,
+                latency=latency,
+                ttft=latency,  # Simplified for non-streaming
+            )
         return response

     def update_weights(self, meta):

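The branch above is a plain subtype dispatch: VLMRequest extends LLMRequest, so one agenerate can return the richer VLMResponse exactly when image data is present. A self-contained sketch of the same pattern (local toy types, not the repo's structs):

    from dataclasses import dataclass, field
    from typing import List

    @dataclass
    class Req:  # stand-in for LLMRequest
        input_ids: List[int] = field(default_factory=list)

    @dataclass
    class VReq(Req):  # stand-in for VLMRequest
        image_data: List[str] = field(default_factory=list)

    def generate(req: Req) -> dict:
        # Vision requests carry the extra field into the response; text requests do not.
        resp = {"input_tokens": req.input_ids}
        if isinstance(req, VReq):
            resp["input_images"] = req.image_data
        return resp

    print(generate(VReq(input_ids=[1, 2], image_data=["<base64 png>"])))
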
@@ -526,7 +528,7 @@ class RemoteSGLangEngine(InferenceEngine):
         ):
             try:
                 data = next(self.data_generator)

             except StopIteration:
                 self.data_generator = iter(dataloader)
                 data = next(self.data_generator)

@@ -65,8 +65,9 @@ def pad_sequences_to_tensors(
         return TensorDict()
     skip_keys = {"pixel_values", "image_grid_thw"}
     max_length = max(
-        len(seq) for item in sequence_list
-        for key, seq in item.items()
+        len(seq)
+        for item in sequence_list
+        for key, seq in item.items()
         if key not in skip_keys
     )
     result = {}

@@ -79,14 +80,18 @@ def pad_sequences_to_tensors(
             x = item[key]
             if not torch.is_tensor(x):
                 x = torch.tensor(x)
-            padded_x=torch.nn.functional.pad(
-                x, (0, max_length - len(item[key])), value=pad_value
-            )
+            padded_x = torch.nn.functional.pad(
+                x, (0, max_length - len(item[key])), value=pad_value
+            )
             padded.append(padded_x)
         result[key] = torch.stack(padded)
     attention_mask = [
         [1] * len(next(iter(item[key] for key in item.keys() if key not in skip_keys)))
-        + [0] * (max_length - len(next(iter(item[key] for key in item.keys() if key not in skip_keys))))
+        + [0]
+        * (
+            max_length
+            - len(next(iter(item[key] for key in item.keys() if key not in skip_keys)))
+        )
         for item in sequence_list
     ]
     result["attention_mask"] = torch.tensor(attention_mask, dtype=torch.bool)

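The helper above right-pads every 1-D field to the batch maximum and derives the attention mask from the pre-padding lengths, skipping the image tensors. A standalone illustration with plain torch (toy data, not the repo's helper):

    import torch
    import torch.nn.functional as F

    seqs = [[5, 6, 7], [8, 9]]  # toy token-id lists of unequal length
    max_length = max(len(s) for s in seqs)

    padded = torch.stack(
        [F.pad(torch.tensor(s), (0, max_length - len(s)), value=0) for s in seqs]
    )
    attention_mask = torch.tensor(
        [[1] * len(s) + [0] * (max_length - len(s)) for s in seqs], dtype=torch.bool
    )
    print(padded)          # tensor([[5, 6, 7], [8, 9, 0]])
    print(attention_mask)  # [[True, True, True], [True, True, False]]
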
@@ -139,7 +144,7 @@ def concat_padded_tensors(
                 tensors_to_concat.append(tensor)
                 continue
             current_length = tensor.shape[1]
-            if key == "pixel_values" or key== "image_grid_thw":
+            if key == "pixel_values" or key == "image_grid_thw":
                 tensors_to_concat.append(tensor)
                 continue
             if current_length < max_length:

@@ -150,7 +155,7 @@ def concat_padded_tensors(
                     padding = torch.zeros(
                         (tensor.shape[0], pad_width), dtype=tensor.dtype
                     )

                 else:
                     # Pad feature tensors with pad_value
                     padding = torch.full(

@@ -323,7 +328,7 @@ def split_padded_tensor_dict_into_mb_list(
     to_split = {}
     not_to_split = {}
     for key, value in data.items():
-        if key=="image_grid_thw" or key=="pixel_values":
+        if key == "image_grid_thw" or key == "pixel_values":
             continue
         if not torch.is_tensor(value) or value.numel() != bs * max_seqlen:
             not_to_split[key] = value

@@ -368,7 +373,7 @@ def split_padded_tensor_dict_into_mb_list(

     for group_index in group_indices:
         group_pixel_values = [pixel_values[i] for i in group_index]
-        group_image_grid_thw = [image_grid_thw[i].squeeze()for i in group_index]
+        group_image_grid_thw = [image_grid_thw[i].squeeze() for i in group_index]

         # Stack pixel_values for each group (assuming pixel_values is a list of tensors)
         pixel_values_split.append(torch.stack(group_pixel_values))

@@ -46,7 +46,7 @@ def fsdp2_clip_grad_norm_(
     grads = [p.grad for p in parameters if p.grad is not None]
     total_norm = _get_total_norm(grads, norm_type, error_if_nonfinite, foreach)
     total_norm = total_norm.to(torch.cuda.current_device(), non_blocking=True)

     _clip_grads_with_norm_(parameters, max_norm, total_norm, foreach)
     return total_norm

@@ -7,20 +7,17 @@ from typing import List
 from PIL.Image import Image as ImageObject


-def image2base64(images: List[ImageObject]|ImageObject)-> List[str]|str:
+def image2base64(images: List[ImageObject] | ImageObject) -> List[str] | str:
     if isinstance(images, ImageObject):
         images = [images]
-

     byte_images = []
     for image in images:
         with BytesIO() as buffer:
             image.save(buffer, format="PNG")
-            buffer.seek(0)
-            byte_image = base64.b64encode(buffer.read()).decode('utf-8')
+            buffer.seek(0)
+            byte_image = base64.b64encode(buffer.read()).decode("utf-8")
             byte_images.append(byte_image)
-

     return byte_images

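image2base64 round-trips a PIL image through an in-memory PNG buffer. A standalone equivalent for a single image (toy input; the function above additionally accepts a list):

    import base64
    from io import BytesIO

    from PIL import Image

    image = Image.new("RGB", (8, 8), color="red")  # toy image
    with BytesIO() as buffer:
        image.save(buffer, format="PNG")
        buffer.seek(0)
        b64 = base64.b64encode(buffer.read()).decode("utf-8")
    print(b64[:32])  # base64 text, ready for a data URL or JSON payload
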
@@ -5,6 +5,7 @@ VALID_VISION_MODELS = [
     "qwen2_5_vl",
 ]

+
 # Copied from trl
 def disable_dropout_in_model(model: torch.nn.Module) -> None:
     for module in model.modules():

@@ -34,7 +34,7 @@ class VL_RLVRWorkflow(RLVRWorkflow):
             return_tensors="pt",
         )

-        input_ids=processed_input["input_ids"].tolist()[0]
+        input_ids = processed_input["input_ids"].tolist()[0]

         n_samples = self.gconfig.n_samples

@@ -62,13 +62,19 @@ class VL_RLVRWorkflow(RLVRWorkflow):
                 completion_ids=resp.output_tokens,
                 **data,
             )

             res = dict(
                 # unsqueeze to add an additional batch dimension
                 input_ids=torch.tensor(seq).unsqueeze(0),
                 loss_mask=torch.tensor(loss_mask).unsqueeze(0),
-                pixel_values=processed_input["pixel_values"].clone().detach().unsqueeze(0),
-                image_grid_thw=processed_input["image_grid_thw"].clone().detach().unsqueeze(0),
+                pixel_values=processed_input["pixel_values"]
+                .clone()
+                .detach()
+                .unsqueeze(0),
+                image_grid_thw=processed_input["image_grid_thw"]
+                .clone()
+                .detach()
+                .unsqueeze(0),
                 logprobs=torch.tensor(logprobs).unsqueeze(0),
                 versions=torch.tensor(versions).unsqueeze(0),
                 attention_mask=torch.ones(len(seq), dtype=torch.bool).unsqueeze(0),

@@ -1,13 +1,15 @@
 import os
 import re
 import sys
-import wandb

 import torch
 import torch.distributed as dist
+import wandb
 from torchdata.stateful_dataloader import StatefulDataLoader

 from arealite.api.cli_args import GRPOConfig, load_expr_config
 from arealite.api.io_struct import FinetuneSpec, WeightUpdateMeta
+from arealite.dataset.__init__ import get_custom_dataset
 from arealite.engine.ppo.actor import FSDPPPOActor
 from arealite.engine.sglang_remote import RemoteSGLangEngine
 from arealite.utils.device import log_gpu_stats

@@ -17,13 +19,12 @@ from arealite.utils.stats_logger import StatsLogger
 from arealite.workflow.vl_rlvr import VL_RLVRWorkflow
 from realhf.api.core.data_api import load_hf_processor_and_tokenizer
 from realhf.base import stats_tracker
-from arealite.dataset.__init__ import get_custom_dataset


 def extract_answer(pred_str, data_name, use_last_number=True):
     match = re.findall(r"\[([0-9\.]+)\]", pred_str)
     if match:
-        return match[-1]
+        return match[-1]

     return ""

@@ -55,32 +56,35 @@ def extract_solution(solution_str, method="strict") -> str | None:
             break
     return final_answer

-def clevr_count_70k_reward_fn(prompt, completions, prompt_ids, completion_ids, answer, **kwargs):
+
+def clevr_count_70k_reward_fn(
+    prompt, completions, prompt_ids, completion_ids, answer, **kwargs
+):
     is_thinking = "thinking" in completions.lower()

-    sol = extract_answer(completions, data_name="") # str number
-    ans =answer
+    sol = extract_answer(completions, data_name="")  # str number
+    ans = answer

     if sol is None:
         return 0
     if ans is None:
         return 0

     if sol.strip() == ans.strip():
         print(f"completions: {completions}, answer: {answer}")
         if is_thinking:
-           return 1
+            return 1
         else:
             return 1

     if re.match(r"^\[\d+(\.\d+)?\]$", sol.strip()):
         return 0.05

     return 0


 def main(args):
-    os.environ["WANDB_API_KEY"]=""
+    os.environ["WANDB_API_KEY"] = ""
     wandb.init(project="clevr_70k")

     config, _ = load_expr_config(args, GRPOConfig)

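The reward function above returns 1 for an exact answer match, 0.05 for output that at least respects the [number] format, and 0 otherwise. A quick standalone check of the format gate, using the same regex:

    import re

    for sol in ["[3]", "[3.5]", "3", "[three]"]:
        ok = bool(re.match(r"^\[\d+(\.\d+)?\]$", sol.strip()))
        print(sol, "->", 0.05 if ok else 0)
    # "[3]" and "[3.5]" get 0.05; bare "3" and "[three]" get 0.
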
@@ -89,22 +93,22 @@ def main(args):
     rank = int(os.getenv("RANK"))
     world_size = int(os.getenv("WORLD_SIZE"))
     processor, tokenizer = load_hf_processor_and_tokenizer(config.tokenizer_path)
-    train_dataset=get_custom_dataset(
-        path=config.train_dataset.path,
-        rank=rank,
-        world_size=world_size,
-        split="train",
-        training_type="rl",
-        processor=processor
-    )
-    valid_dataset=get_custom_dataset(
-        path=config.valid_dataset.path,
-        rank=rank,
-        world_size=world_size,
-        split="test",
-        training_type="rl",
-        processor=processor
-    )
+    train_dataset = get_custom_dataset(
+        path=config.train_dataset.path,
+        rank=rank,
+        world_size=world_size,
+        split="train",
+        training_type="rl",
+        processor=processor,
+    )
+    valid_dataset = get_custom_dataset(
+        path=config.valid_dataset.path,
+        rank=rank,
+        world_size=world_size,
+        split="test",
+        training_type="rl",
+        processor=processor,
+    )
     # Create dataset and dataloaders
     train_dataloader = StatefulDataLoader(
         train_dataset,

@@ -141,7 +145,7 @@ def main(args):
     actor.initialize(None, ft_spec)
     ref = None
     if config.actor.kl_ctl > 0 and config.ref is not None:
-        ref =FSDPPPOActor(config=config.ref)
+        ref = FSDPPPOActor(config=config.ref)
         ref.initialize(None, ft_spec)

     # Create rollout workflow

@@ -199,7 +203,7 @@ def main(args):
             if ref is not None:
                 with stats_tracker.record_timing("ref_logp"):
                     batch["ref_logp"] = ref.compute_logp(batch)

                     log_gpu_stats("ref logp")

             with stats_tracker.record_timing("compute_advantage"):

@@ -211,8 +215,8 @@ def main(args):
                 stats_tracker.scope("grpo_actor"),
             ):
                 stats = actor.ppo_update(batch)
-                wandb.log({"actor_reward": stats[0]['grpo_actor/final_reward/avg']})
+                wandb.log({"actor_reward": stats[0]["grpo_actor/final_reward/avg"]})

                 actor.step_lr_scheduler()
                 log_gpu_stats("ppo update")

@@ -251,7 +255,14 @@ def main(args):
                     cnt += 1
                 batch = eval_rollout.wait(cnt, timeout=None)
                 rewards = batch["rewards"].float().to(actor.device)
-                wandb.log({"eval_reward": rewards.mean().item(), "epoch": epoch, "step": step, "global_step": global_step})
+                wandb.log(
+                    {
+                        "eval_reward": rewards.mean().item(),
+                        "epoch": epoch,
+                        "step": step,
+                        "global_step": global_step,
+                    }
+                )
                 with stats_tracker.scope("grpo-eval"):
                     stats_tracker.denominator(
                         n_seqs=torch.ones(

@@ -280,6 +291,6 @@ def main(args):
     actor.destroy()
     wandb.finish()


 if __name__ == "__main__":
     main(sys.argv[1:])

@@ -1,9 +1,11 @@
 import os
 import sys

 from torchdata.stateful_dataloader import StatefulDataLoader

 from arealite.api.cli_args import SFTConfig, load_expr_config
 from arealite.api.io_struct import FinetuneSpec
+from arealite.dataset.__init__ import get_custom_dataset
 from arealite.engine.sft.lm_engine import FSDPLMEngine
 from arealite.utils.data import pad_sequences_to_tensors
 from arealite.utils.evaluator import Evaluator

@@ -11,7 +13,6 @@ from arealite.utils.saver import Saver
 from arealite.utils.stats_logger import StatsLogger
 from realhf.api.core.data_api import load_hf_processor_and_tokenizer
 from realhf.base import stats_tracker
-from arealite.dataset.__init__ import get_custom_dataset


 def main_sft():

@@ -21,25 +22,25 @@ def main_sft():
     rank = int(os.getenv("RANK"))
     world_size = int(os.getenv("WORLD_SIZE"))
     processor, tokenizer = load_hf_processor_and_tokenizer(config.tokenizer_path)
-    train_dataset=get_custom_dataset(
-        path=config.train_dataset.path,
-        rank=rank,
-        world_size=world_size,
-        split="train",
-        training_type="sft",
-        tokenizer=tokenizer,
-        processor=processor,
-    )
-    valid_dataset=get_custom_dataset(
-        path=config.valid_dataset.path,
-        rank=rank,
-        world_size=world_size,
-        split="test",
-        training_type="sft",
-        tokenizer=tokenizer,
-        processor=processor,
-    )
+    train_dataset = get_custom_dataset(
+        path=config.train_dataset.path,
+        rank=rank,
+        world_size=world_size,
+        split="train",
+        training_type="sft",
+        tokenizer=tokenizer,
+        processor=processor,
+    )
+    valid_dataset = get_custom_dataset(
+        path=config.valid_dataset.path,
+        rank=rank,
+        world_size=world_size,
+        split="test",
+        training_type="sft",
+        tokenizer=tokenizer,
+        processor=processor,
+    )

     # Create dataset and dataloaders
     train_dataloader = StatefulDataLoader(
         train_dataset,

@@ -1,26 +1,39 @@
-from typing import Any, Dict, Optional, Union
 import base64
 import math
-from PIL.Image import Image as ImageObject
 import os
+from io import BytesIO
+from typing import Any, Dict, Optional, Union

 from datasets import load_dataset
 from datasets.distributed import split_dataset_by_node
-import base64
-from io import BytesIO
+from PIL.Image import Image as ImageObject


-def input_text(text:str):
+def input_text(text: str):
     return {"type": "input_text", "text": text}


 def input_image(base64_image: str):
-    return {"type": "input_image", "image_url": f"data:image/jpeg;base64,{base64_image}"}
-def build_raw_message(sample: Dict[str, Any], base64_images: list[str]) -> list[Dict[str, Any]]:
+    return {
+        "type": "input_image",
+        "image_url": f"data:image/jpeg;base64,{base64_image}",
+    }
+
+
+def build_raw_message(
+    sample: Dict[str, Any], base64_images: list[str]
+) -> list[Dict[str, Any]]:
     raw_message = []
-    problem_parts = [part.strip() for part in sample["problem"].split("<image>") if part.strip()]
+    problem_parts = [
+        part.strip() for part in sample["problem"].split("<image>") if part.strip()
+    ]
     insert_list = []
     for i, part in enumerate(problem_parts):
-        if i > 0 or sample["problem"].startswith("<image>"):
+        if i > 0 or sample["problem"].startswith("<image>"):
             insert_list.append("image")
-        part = part.strip()
-        if part:
+        part = part.strip()
+        if part:
             insert_list.append("text")
     image_index = 0
     text_index = 0

@@ -38,17 +51,25 @@ def build_raw_message(sample: Dict[str, Any], base64_images: list[str]) -> list[Dict[str, Any]]:

 def encode_image(image_file):
     return base64.b64encode(image_file).decode("utf-8")


 def convert_image(
-    image: Union[Dict[str, Any], ImageObject, str], min_pixels: Optional[int], max_pixels: Optional[int]
+    image: Union[Dict[str, Any], ImageObject, str],
+    min_pixels: Optional[int],
+    max_pixels: Optional[int],
 ) -> ImageObject:
     if max_pixels is not None and (image.width * image.height) > max_pixels:
         resize_factor = math.sqrt(max_pixels / (image.width * image.height))
-        width, height = int(image.width * resize_factor), int(image.height * resize_factor)
+        width, height = int(image.width * resize_factor), int(
+            image.height * resize_factor
+        )
         image = image.resize((width, height))

     if min_pixels is not None and (image.width * image.height) < min_pixels:
         resize_factor = math.sqrt(min_pixels / (image.width * image.height))
-        width, height = int(image.width * resize_factor), int(image.height * resize_factor)
+        width, height = int(image.width * resize_factor), int(
+            image.height * resize_factor
+        )
         image = image.resize((width, height))

     if image.mode != "RGB":

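convert_image rescales by sqrt(budget / current_pixels), so the resized area lands just under the pixel budget while the aspect ratio is preserved. The arithmetic on toy sizes:

    import math

    width, height = 800, 600  # toy input: 480,000 pixels
    max_pixels = 336 * 336    # the 112,896-pixel budget used by the dataset code below

    factor = math.sqrt(max_pixels / (width * height))
    new_w, new_h = int(width * factor), int(height * factor)
    print(new_w, new_h, new_w * new_h)  # 387 290 112230 -- under budget, ratio kept
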
@@ -57,32 +78,34 @@ def convert_image(
         image.save(output, format="JPEG")
         return output.getvalue()


 def get_clevr_count_70k_sft_dataset(path, split, processor, rank, world_size):
-    '''
+    """
     "clevr_count_70k": {
         "image_key": "images",
         "question_key": "problem",
         "answer_key": "answer"
     },
-    '''
+    """
     dataset = load_dataset(path=path, split=split)
     dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
-    tokenizer = processor.tokenizer
+
+    tokenizer = processor.tokenizer

     def process_example(example, idx):
         # Add query_id column
         images = example["images"]
-        if 'qwen' in processor.image_processor.image_processor_type.lower():
-            image_token="<|vision_start|><|image_pad|><|vision_end|>"
+        if "qwen" in processor.image_processor.image_processor_type.lower():
+            image_token = "<|vision_start|><|image_pad|><|vision_end|>"
         else:
             image_token = processor.image_token if processor is not None else "<image>"
         example["problem"] = example["problem"].replace("<image>", image_token)
         processed_images = []
         for image in images:
-            processed_images.append(convert_image(image,113*113,336*336))
+            processed_images.append(convert_image(image, 113 * 113, 336 * 336))
         example["images"] = processed_images
         example["seq"] = example["problem"] + example["answer"] + tokenizer.eos_token

         return example

     dataset = dataset.map(

@@ -91,8 +114,8 @@ def get_clevr_count_70k_sft_dataset(path, split, processor, rank, world_size):
     )

     def _process(example):
-        text=example["seq"]
-        processed_input=processor(
+        text = example["seq"]
+        processed_input = processor(
             text=[text],
             images=example["images"],
             padding=False,

@@ -101,38 +124,52 @@ def get_clevr_count_70k_sft_dataset(path, split, processor, rank, world_size):
             return_attention_mask=False,
         )

-        example["input_ids"] =processed_input["input_ids"].squeeze(0)
+        example["input_ids"] = processed_input["input_ids"].squeeze(0)
         example["pixel_values"] = processed_input["pixel_values"]
         example["image_grid_thw"] = processed_input["image_grid_thw"]
         answer_token = tokenizer.encode(example["answer"])
-        loss_mask = [0] * (len(example["input_ids"]) - len(answer_token))+[1]*len(answer_token)
-        example["loss_mask"]=loss_mask
+        loss_mask = [0] * (len(example["input_ids"]) - len(answer_token)) + [1] * len(
+            answer_token
+        )
+        example["loss_mask"] = loss_mask
         return example

-    dataset = dataset.map(lambda x: _process(x),remove_columns=["images","seq","problem","answer"])
+    dataset = dataset.map(
+        lambda x: _process(x), remove_columns=["images", "seq", "problem", "answer"]
+    )
     return dataset

-def get_clevr_count_70k_rl_dataset(path, split,processor, rank, world_size):
+
+def get_clevr_count_70k_rl_dataset(path, split, processor, rank, world_size):
     dataset = load_dataset(path=path, split=split)
     dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)

     def process(sample):
-        processed_images = [convert_image(image, 113*113, 336*336) for image in sample["images"]]
-        if 'qwen' in processor.image_processor.image_processor_type.lower():
-            image_token="<|vision_start|><|image_pad|><|vision_end|>"
+        processed_images = [
+            convert_image(image, 113 * 113, 336 * 336) for image in sample["images"]
+        ]
+        if "qwen" in processor.image_processor.image_processor_type.lower():
+            image_token = "<|vision_start|><|image_pad|><|vision_end|>"
         else:
             image_token = processor.image_token if processor is not None else "<image>"
         system_prompt = {
             "role": "system",
             "content": (
                 "Solve the following question: count the number of items in the image and provide the final answer in [ ] format, ensuring that only the number is inside the brackets without any additional text or explanations. "
-            )
+            ),
         }

-        messages =[{"role": "user", "content": sample["problem"].replace("<image>", image_token)}]
+        messages = [
+            {
+                "role": "user",
+                "content": sample["problem"].replace("<image>", image_token),
+            }
+        ]
         messages.insert(0, system_prompt)
-        messages=processor.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+        messages = processor.tokenizer.apply_chat_template(
+            messages, add_generation_prompt=True, tokenize=False
+        )
         return {"messages": messages, "images": processed_images}

     dataset = dataset.map(process).remove_columns(["problem"])
-    return dataset
+    return dataset

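The SFT loss_mask above zeroes prompt positions and keeps answer positions, relying on the answer tokens forming the tail of input_ids. The computation on toy values:

    input_ids = [101, 102, 103, 104, 105]  # toy prompt + answer token ids
    answer_token = [104, 105]              # toy answer tokens (the tail of input_ids)

    loss_mask = [0] * (len(input_ids) - len(answer_token)) + [1] * len(answer_token)
    print(loss_mask)  # [0, 0, 0, 1, 1] -- loss is computed only on the answer span
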
@@ -1,9 +1,11 @@
-from datasets import load_dataset, Dataset
+from datasets import Dataset, load_dataset
 from datasets.distributed import split_dataset_by_node


 def get_gsm8k_sft_dataset(path, split, tokenizer, rank, world_size):
     dataset = load_dataset(path=path, name="main", split=split)
     dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)

     def process(sample):
         seq_token = tokenizer.encode(
             sample["question"] + sample["answer"] + tokenizer.eos_token

@@ -15,12 +17,14 @@ def get_gsm8k_sft_dataset(path, split, tokenizer, rank, world_size):
     dataset = dataset.map(process).remove_columns(["question", "answer"])
     return dataset

-def get_gsm8k_rl_dataset(path,split, rank, world_size):
+
+def get_gsm8k_rl_dataset(path, split, rank, world_size):
     dataset = load_dataset(path=path, name="main", split=split)
     dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)

     def process(sample):
         messages = [{"role": "user", "content": sample["question"]}]
         return {"messages": messages}

     dataset = dataset.map(process).remove_columns(["question"])
-    return dataset
+    return dataset

@@ -68,6 +68,7 @@ def load_hf_tokenizer(
         tokenizer.pad_token_id = tokenizer.eos_token_id
     return tokenizer

+
 @lru_cache(maxsize=8)
 def load_hf_processor_and_tokenizer(
     model_name_or_path: str,