mirror of https://github.com/inclusionAI/AReaL
Merge branch 'lcy/refactor' of https://code.alipay.com/inclusionAI/AReaL into lcy/refactor
This commit is contained in:
commit 391bd85e44
@ -0,0 +1,95 @@
---
Language: Cpp
AccessModifierOffset: -1
AlignAfterOpenBracket: Align
AlignConsecutiveAssignments: false
AlignConsecutiveDeclarations: false
AlignEscapedNewlinesLeft: true
AlignOperands: true
AlignTrailingComments: true
AllowAllParametersOfDeclarationOnNextLine: true
AllowShortBlocksOnASingleLine: true
AllowShortCaseLabelsOnASingleLine: true
AllowShortFunctionsOnASingleLine: All
AllowShortIfStatementsOnASingleLine: true
AllowShortLoopsOnASingleLine: true
AlwaysBreakAfterDefinitionReturnType: None
AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: false
AlwaysBreakTemplateDeclarations: true
BinPackArguments: true
BinPackParameters: true
BraceWrapping:
  AfterClass: true
  AfterControlStatement: false
  AfterEnum: false
  AfterFunction: false
  AfterNamespace: false
  AfterObjCDeclaration: false
  AfterStruct: false
  AfterUnion: false
  BeforeCatch: false
  BeforeElse: false
  IndentBraces: false
BreakBeforeBinaryOperators: NonAssignment
BreakBeforeBraces: Attach
BreakBeforeTernaryOperators: true
BreakConstructorInitializersBeforeComma: false
BreakAfterJavaFieldAnnotations: false
BreakStringLiterals: true
ColumnLimit: 100
CommentPragmas: '^ IWYU pragma:'
BreakBeforeInheritanceComma: false
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
Cpp11BracedListStyle: true
DisableFormat: false
ExperimentalAutoDetectBinPacking: false
FixNamespaceComments: true
ForEachMacros: [ foreach, Q_FOREACH, BOOST_FOREACH ]
IncludeCategories:
  - Regex: '^<.*\.h>'
    Priority: 1
  - Regex: '^<.*'
    Priority: 2
  - Regex: '.*'
    Priority: 3
IncludeIsMainRegex: '([-_](test|unittest))?$'
IndentCaseLabels: true
IndentWidth: 2
IndentWrappedFunctionNames: false
JavaScriptQuotes: Leave
JavaScriptWrapImports: true
KeepEmptyLinesAtTheStartOfBlocks: false
MacroBlockBegin: ''
MacroBlockEnd: ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
ObjCBlockIndentWidth: 2
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: false
PenaltyBreakBeforeFirstCallParameter: 1
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
PenaltyBreakString: 1000
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 200
PointerAlignment: Right
ReflowComments: true
SortIncludes: false
SpaceAfterCStyleCast: false
SpaceAfterTemplateKeyword: false
SpaceBeforeAssignmentOperators: true
SpaceBeforeParens: ControlStatements
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 2
SpacesInAngles: false
SpacesInContainerLiterals: true
SpacesInCStyleCastParentheses: false
SpacesInParentheses: false
SpacesInSquareBrackets: false
Standard: Auto
TabWidth: 8
UseTab: Never
...

@ -0,0 +1,183 @@
# Legacy codes
.legacy/
.data/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
trace_result/
profile_result/

slurm_outs
_data
*.nfs*
output
logs

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
# dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# openai api key
api_key.txt
api_key.json

./*.sh
*.png
*.jpg
*.pdf

.vscode/

@ -690,7 +690,7 @@ class ClusterSpecConfig:

 @dataclass
 class DatasetConfig:
-    path: str =field(
+    path: str = field(
         default=MISSING,
         metadata={
             "help": "Path to the dataset. Can be a local path or a HuggingFace dataset name."

@ -47,14 +47,17 @@ class LLMResponse:
     @property
     def output_len(self) -> int:
         return len(self.output_tokens)


 @dataclass
 class VLMRequest(LLMRequest):
-    image_data: Optional[List[ImageObject|str]] = field(default_factory=list)
+    image_data: Optional[List[ImageObject | str]] = field(default_factory=list)


 @dataclass
 class VLMResponse(LLMResponse):
-    input_images: List[ImageObject|str] = field(default_factory=list)
+    input_images: List[ImageObject | str] = field(default_factory=list)


 @dataclass
 class FinetuneSpec:

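VLMRequest extends LLMRequest with an image_data field, and VLMResponse records the input images next to the usual token outputs. A minimal sketch of building a request; the field set beyond image_data (input_ids is used by agenerate later in this diff) is assumed, so treat it as illustrative rather than the exact constructor contract:

from arealite.api.io_struct import VLMRequest

# image_data accepts PIL images or base64 strings (ImageObject | str).
req = VLMRequest(
    input_ids=[1, 2, 3],          # pre-tokenized prompt, as consumed by agenerate
    image_data=["<base64-png>"],  # e.g. the output of image2base64(...)
)
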
@ -142,6 +145,7 @@ class AllocationMode:
             raise ValueError(
                 f"Unknown how to resolve parallelism strategy: {allocation_mode}"
             )

     @staticmethod
     def extract_decoupled_alloc(allocation_mode: str) -> Dict:
         pattern = re.compile(

@ -4,36 +4,39 @@ import transformers

 VALID_DATASETS = ["gsm8k", "clevr_count_70k"]


 def get_custom_dataset(
     path: str,
     rank: int,
     world_size: int,
-    training_type: str= "sft",
+    training_type: str = "sft",
     split: Optional[str] = None,
     tokenizer: Optional[transformers.PreTrainedTokenizerFast] = None,
     processor: Optional[transformers.AutoProcessor] = None,
 ):

     if "gsm8k" in path and training_type == "sft":
         from examples.arealite.dataset.gsm8k import get_gsm8k_sft_dataset

         return get_gsm8k_sft_dataset(path, split, tokenizer, rank, world_size)
     elif "gsm8k" in path and training_type == "rl":
         from examples.arealite.dataset.gsm8k import get_gsm8k_rl_dataset

         return get_gsm8k_rl_dataset(path, split, rank, world_size)
     elif "clevr_count_70k" in path and training_type == "sft":
         from examples.arealite.dataset.clevr_count_70k import (
             get_clevr_count_70k_sft_dataset,
         )

         return get_clevr_count_70k_sft_dataset(path, split, processor, rank, world_size)
     elif "clevr_count_70k" in path and training_type == "rl":
         from examples.arealite.dataset.clevr_count_70k import (
             get_clevr_count_70k_rl_dataset,
         )
-        return get_clevr_count_70k_rl_dataset(path, split,processor, rank, world_size)
+
+        return get_clevr_count_70k_rl_dataset(path, split, processor, rank, world_size)
     else:
         raise ValueError(
             f"Dataset {path} with split {split} and training type {training_type} is not supported. "
             f"Supported datasets are: {VALID_DATASETS}. "
         )

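get_custom_dataset is a thin dispatcher: it matches a substring of the dataset path plus training_type and forwards to the matching loader, raising ValueError for anything else. A hedged usage sketch; the import path follows the training scripts later in this diff, and the model and dataset names are illustrative placeholders:

from arealite.dataset.__init__ import get_custom_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B")  # illustrative model

# "gsm8k" in the path plus training_type="sft" dispatches to
# get_gsm8k_sft_dataset; unsupported combinations raise ValueError.
dataset = get_custom_dataset(
    path="openai/gsm8k",
    rank=0,
    world_size=1,
    training_type="sft",
    split="train",
    tokenizer=tokenizer,
)
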
@ -9,8 +9,8 @@ from tensordict import TensorDict
 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
-    AutoProcessor,
     AutoModelForImageTextToText,
+    AutoProcessor,
     PretrainedConfig,
     PreTrainedTokenizerFast,
     get_constant_schedule_with_warmup,

@ -31,8 +31,8 @@ from arealite.utils.data import (
     unsqueeze_mb_list,
 )
 from arealite.utils.fsdp import get_cosine_schedule_with_warmup
-from arealite.utils.model import disable_dropout_in_model,VALID_VISION_MODELS
-from realhf.api.core.data_api import load_hf_tokenizer, load_hf_processor_and_tokenizer
+from arealite.utils.model import VALID_VISION_MODELS, disable_dropout_in_model
+from realhf.api.core.data_api import load_hf_processor_and_tokenizer, load_hf_tokenizer
 from realhf.base import constants, logging

 logger = logging.getLogger("Base HF Engine")

@ -55,7 +55,7 @@ class BaseHFEngine(TrainEngine):
         self.own_global_group = False
         self._parallelism_group: dist.ProcessGroup
         self.weight_update_group_initialized = False

         self.model_config = AutoConfig.from_pretrained(
             pretrained_model_name_or_path=self.config.path,
             trust_remote_code=True,

@ -96,8 +96,10 @@ class BaseHFEngine(TrainEngine):
         dtype = getattr(torch, self.config.dtype)

         if self.is_vision_model:
             dtype = torch.bfloat16
-        self.processor, self.tokenizer = load_hf_processor_and_tokenizer(self.config.path)
+        self.processor, self.tokenizer = load_hf_processor_and_tokenizer(
+            self.config.path
+        )

         tik = time.perf_counter()
         with torch.device("cuda"):

@ -132,7 +134,7 @@ class BaseHFEngine(TrainEngine):
         )
         if self.config.disable_dropout:
             disable_dropout_in_model(model)

         if self.config.gradient_checkpointing:
             model.gradient_checkpointing_enable(
                 gradient_checkpointing_kwargs={"use_reentrant": False}

@ -231,14 +233,12 @@ class BaseHFEngine(TrainEngine):
         assert self.lr_scheduler is not None
         self.lr_scheduler.step()

-
-
     def prepare_mb_list(self, input_: TensorDict) -> MicroBatchList:
         assert "attention_mask" in input_ and "input_ids" in input_
         if self.is_vision_model:
-            assert "pixel_values" in input_ and "image_grid_thw" in input_, (
-                "For vision-language models, pixel_values and image_grid_thw must be present in input_"
-            )
+            assert (
+                "pixel_values" in input_ and "image_grid_thw" in input_
+            ), "For vision-language models, pixel_values and image_grid_thw must be present in input_"

         if isinstance(input_, dict):
             input_ = TensorDict(input_, batch_size=[input_["input_ids"].shape[0]])

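The reshaped assert keeps the same contract: language-model inputs always carry input_ids and attention_mask, and vision-language inputs must additionally carry pixel_values and image_grid_thw. A minimal sketch of a batch that satisfies the check; all shapes are illustrative assumptions, not the engine's required layout:

import torch
from tensordict import TensorDict

batch = TensorDict(
    {
        "input_ids": torch.randint(0, 1000, (2, 16)),
        "attention_mask": torch.ones(2, 16, dtype=torch.bool),
        # Required only when engine.is_vision_model is True:
        "pixel_values": torch.randn(2, 4, 1176),
        "image_grid_thw": torch.tensor([[1, 2, 2], [1, 2, 2]]),
    },
    batch_size=[2],
)
# mb_list = engine.prepare_mb_list(batch)  # engine: a BaseHFEngine subclass
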
@ -303,7 +303,6 @@ class BaseHFEngine(TrainEngine):

         loss *= loss_scale

-
         grad_norm = torch.nn.utils.clip_grad_norm_(
             self.model.parameters(),
             self.optimizer_config.gradient_clipping,

@ -2,7 +2,7 @@ import os
 import time
 from datetime import datetime
 from typing import Callable, Dict, Optional, Tuple
-from realhf.api.core.data_api import load_hf_processor_and_tokenizer
+
 import torch
 import torch.distributed as dist
 from tensordict import TensorDict

@ -11,7 +11,7 @@ from torch.distributed.checkpoint.state_dict import (
     StateDictOptions,
     get_model_state_dict,
 )
-from transformers import PreTrainedTokenizerFast,AutoProcessor
+from transformers import AutoProcessor, PreTrainedTokenizerFast

 from arealite.api.cli_args import TrainEngineConfig
 from arealite.api.engine_api import FinetuneSpec, SaveLoadMeta, WeightUpdateMeta

@ -26,6 +26,7 @@ from arealite.utils.fsdp import (
     fsdp2_load_full_state_dict,
 )
 from arealite.utils.save_load import get_state_dict_from_repo_id_or_path
+from realhf.api.core.data_api import load_hf_processor_and_tokenizer
 from realhf.base import logging, name_resolve, names, pkg_version

 logger = logging.getLogger("FSDPEngine")

@ -48,7 +49,6 @@ class FSDPEngine(BaseHFEngine):

         self.create_process_group()
         self.create_device_model()

-
         # Wrap with FSDP2
         # Simple auto wrap policy

@ -100,7 +100,10 @@ class FSDPEngine(BaseHFEngine):
             self.load_optimizer_state(meta.path)

     def _save_model_to_hf(
-        self, path: str, tokenizer: Optional[PreTrainedTokenizerFast], processor: Optional[AutoProcessor]
+        self,
+        path: str,
+        tokenizer: Optional[PreTrainedTokenizerFast],
+        processor: Optional[AutoProcessor],
     ):
         """Save model in HuggingFace format."""
         if self.model is None:

@ -146,7 +149,7 @@ class FSDPEngine(BaseHFEngine):
             dist.barrier()
             torch.cuda.synchronize()
         elif meta.type == "disk":
-            self._save_model_to_hf(meta.path, self.tokenizer,self.processor)
+            self._save_model_to_hf(meta.path, self.tokenizer, self.processor)
             # dist.barrier() are called when _save_model_to_hf finished
             if dist.get_rank() == 0:
                 update_name = names.update_weights_from_disk(

@ -239,7 +242,6 @@ class FSDPEngine(BaseHFEngine):

         loss *= loss_scale
         loss.backward()

-
         # NOTE: grad norm clip function is different

@ -257,8 +257,6 @@ class FSDPPPOActor(FSDPEngine):
         return self.actor.ppo_update(*args, **kwargs)


-
-
 def grpo_loss_fn(
     logits: torch.Tensor,
     input_data: Dict,

@ -5,7 +5,6 @@ from tensordict import TensorDict
 from arealite.api.cli_args import TrainEngineConfig
 from arealite.api.engine_api import TrainEngine
 from arealite.engine.fsdp_engine import FSDPEngine
-# from arealite.engine.vl_fsdp_engine import VL_FSDPEngine
 from arealite.utils.functional import gather_logprobs
 from realhf.base import stats_tracker

@ -42,6 +41,7 @@ class FSDPLMEngine(FSDPEngine):
     def evaluate_lm(self, data):
         return self.lm_engine.evaluate_lm(data)

+
 def compute_packed_sft_loss(logits: torch.Tensor, input_: TensorDict) -> torch.Tensor:
     packed_input_ids: torch.Tensor = input_["input_ids"]
     cu_seqlens: torch.Tensor = input_["cu_seqlens"]

@ -50,7 +50,7 @@ def compute_packed_sft_loss(logits: torch.Tensor, input_: TensorDict) -> torch.T
     logprobs = gather_logprobs(logits, torch.roll(packed_input_ids, shifts=-1, dims=-1))
     loss_mask = torch.roll(loss_mask, shifts=-1, dims=-1)
     logprobs = torch.where(loss_mask, logprobs, 0)

     loss = -logprobs.sum() / loss_mask.count_nonzero()
     with torch.no_grad():
         seqlogp = torch.zeros(

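For context, the loss above is a standard token-level negative log-likelihood over packed sequences: per-token log-probabilities of the next token are gathered (inputs rolled left by one), masked, and averaged over the unmasked positions. A self-contained sketch of the same computation; this re-implements the idea and is not the library's gather_logprobs:

import torch
import torch.nn.functional as F

def packed_nll(logits: torch.Tensor, input_ids: torch.Tensor, loss_mask: torch.Tensor):
    # Next-token targets: shift inputs left by one position.
    targets = torch.roll(input_ids, shifts=-1, dims=-1)
    logprobs = F.log_softmax(logits, dim=-1)
    token_logprobs = logprobs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    mask = torch.roll(loss_mask, shifts=-1, dims=-1).bool()
    token_logprobs = torch.where(mask, token_logprobs, torch.zeros_like(token_logprobs))
    return -token_logprobs.sum() / mask.count_nonzero()
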
@ -23,9 +23,9 @@ from arealite.api.io_struct import (
     LLMRequest,
     LLMResponse,
     RolloutStat,
+    VLMRequest,
+    VLMResponse,
     WeightUpdateMeta,
-    VLMRequest,
-    VLMResponse
 )
 from arealite.utils.data import concat_padded_tensors
 from arealite.utils.http import arequest_with_retry, get_default_connector

@ -228,7 +228,9 @@ class RemoteSGLangEngine(InferenceEngine):
             return server
         raise NotImplementedError("Only round-robin scheduling is implemented.")

-    async def agenerate(self, req: LLMRequest|VLMRequest) -> LLMResponse|VLMResponse:
+    async def agenerate(
+        self, req: LLMRequest | VLMRequest
+    ) -> LLMResponse | VLMResponse:
         """Async version of generate using aiohttp."""
         # Prepare request payload
         gconfig = req.gconfig

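agenerate accepts either request type and returns the matching response type, so VLM traffic reuses the LLM code path. A hedged usage sketch; engine construction and request fields are elided, and the asyncio wiring is shown only to illustrate that this is a coroutine:

import asyncio

async def run(engine, req):
    # Works for LLMRequest or VLMRequest; the response type mirrors the request.
    resp = await engine.agenerate(req)
    print(resp.stop_reason, resp.output_len)

# asyncio.run(run(engine, req))  # engine: RemoteSGLangEngine; req as built earlier
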
@ -318,28 +320,28 @@ class RemoteSGLangEngine(InferenceEngine):
             sample_params["max_new_tokens"] -= len(output_tokens)

         latency = time.perf_counter() - start_time

         if isinstance(req, VLMRequest):
             response = VLMResponse(
                 input_tokens=req.input_ids,
                 input_images=req.image_data,
                 output_tokens=accumulated_output_tokens,
                 output_logprobs=accumulated_output_logprobs,
                 output_versions=accumulated_versions,
                 stop_reason=stop_reason,
                 latency=latency,
                 ttft=latency,  # Simplified for non-streaming
             )
         else:
-            response=LLMResponse(
+            response = LLMResponse(
                 input_tokens=req.input_ids,
                 output_tokens=accumulated_output_tokens,
                 output_logprobs=accumulated_output_logprobs,
                 output_versions=accumulated_versions,
                 stop_reason=stop_reason,
                 latency=latency,
                 ttft=latency,  # Simplified for non-streaming
             )
         return response

     def update_weights(self, meta):

@ -526,7 +528,7 @@ class RemoteSGLangEngine(InferenceEngine):
         ):
             try:
                 data = next(self.data_generator)

             except StopIteration:
                 self.data_generator = iter(dataloader)
                 data = next(self.data_generator)

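The try/except around next() implements an endless stream over a finite dataloader: when the iterator is exhausted, it is rebuilt and the first batch of the new epoch is returned. The same pattern in isolation, as a minimal sketch:

def endless_batches(dataloader):
    it = iter(dataloader)
    while True:
        try:
            yield next(it)
        except StopIteration:
            it = iter(dataloader)  # restart at the next epoch
            yield next(it)
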
@ -65,8 +65,9 @@ def pad_sequences_to_tensors(
         return TensorDict()
     skip_keys = {"pixel_values", "image_grid_thw"}
     max_length = max(
-        len(seq) for item in sequence_list
-        for key, seq in item.items()
+        len(seq)
+        for item in sequence_list
+        for key, seq in item.items()
         if key not in skip_keys
     )
     result = {}

@ -79,14 +80,18 @@
             x = item[key]
             if not torch.is_tensor(x):
                 x = torch.tensor(x)
-            padded_x=torch.nn.functional.pad(
+            padded_x = torch.nn.functional.pad(
                 x, (0, max_length - len(item[key])), value=pad_value
             )
             padded.append(padded_x)
         result[key] = torch.stack(padded)
     attention_mask = [
         [1] * len(next(iter(item[key] for key in item.keys() if key not in skip_keys)))
-        + [0] * (max_length - len(next(iter(item[key] for key in item.keys() if key not in skip_keys))))
+        + [0]
+        * (
+            max_length
+            - len(next(iter(item[key] for key in item.keys() if key not in skip_keys)))
+        )
         for item in sequence_list
     ]
     result["attention_mask"] = torch.tensor(attention_mask, dtype=torch.bool)

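Vision tensors (pixel_values, image_grid_thw) are excluded when computing the padding length because their leading dimension is not sequence length. A worked example of the padding itself, with illustrative values:

import torch

seqs = [[5, 6, 7], [8, 9]]  # two sequences, max_length = 3
padded = torch.stack(
    [torch.nn.functional.pad(torch.tensor(s), (0, 3 - len(s)), value=0) for s in seqs]
)
mask = torch.tensor([[1] * len(s) + [0] * (3 - len(s)) for s in seqs], dtype=torch.bool)
# padded -> [[5, 6, 7], [8, 9, 0]]; mask -> [[True, True, True], [True, True, False]]
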
@ -139,7 +144,7 @@ def concat_padded_tensors(
             tensors_to_concat.append(tensor)
             continue
         current_length = tensor.shape[1]
-        if key == "pixel_values" or key== "image_grid_thw":
+        if key == "pixel_values" or key == "image_grid_thw":
             tensors_to_concat.append(tensor)
             continue
         if current_length < max_length:

@ -150,7 +155,7 @@
                 padding = torch.zeros(
                     (tensor.shape[0], pad_width), dtype=tensor.dtype
                 )

             else:
                 # Pad feature tensors with pad_value
                 padding = torch.full(

@ -323,7 +328,7 @@ def split_padded_tensor_dict_into_mb_list(
     to_split = {}
     not_to_split = {}
     for key, value in data.items():
-        if key=="image_grid_thw" or key=="pixel_values":
+        if key == "image_grid_thw" or key == "pixel_values":
             continue
         if not torch.is_tensor(value) or value.numel() != bs * max_seqlen:
             not_to_split[key] = value

@ -368,7 +373,7 @@

     for group_index in group_indices:
         group_pixel_values = [pixel_values[i] for i in group_index]
-        group_image_grid_thw = [image_grid_thw[i].squeeze()for i in group_index]
+        group_image_grid_thw = [image_grid_thw[i].squeeze() for i in group_index]

         # Stack pixel_values for each group (assuming pixel_values is a list of tensors)
         pixel_values_split.append(torch.stack(group_pixel_values))

@ -46,7 +46,7 @@ def fsdp2_clip_grad_norm_(
     grads = [p.grad for p in parameters if p.grad is not None]
     total_norm = _get_total_norm(grads, norm_type, error_if_nonfinite, foreach)
     total_norm = total_norm.to(torch.cuda.current_device(), non_blocking=True)

     _clip_grads_with_norm_(parameters, max_norm, total_norm, foreach)
     return total_norm

@ -7,20 +7,17 @@ from typing import List
 from PIL.Image import Image as ImageObject


-def image2base64(images: List[ImageObject]|ImageObject)-> List[str]|str:
-
+def image2base64(images: List[ImageObject] | ImageObject) -> List[str] | str:
     if isinstance(images, ImageObject):
         images = [images]

     byte_images = []
     for image in images:
         with BytesIO() as buffer:
             image.save(buffer, format="PNG")
             buffer.seek(0)
-            byte_image = base64.b64encode(buffer.read()).decode('utf-8')
+            byte_image = base64.b64encode(buffer.read()).decode("utf-8")
             byte_images.append(byte_image)

     return byte_images
-
-

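image2base64 accepts a single PIL image or a list and always encodes to PNG; note that per the body shown above it returns a list even for a single input, despite the annotated return type. A runnable usage example:

from PIL import Image

img = Image.new("RGB", (8, 8), color="red")
b64 = image2base64(img)  # a single image still comes back as a one-element list
assert isinstance(b64, list) and b64[0].startswith("iVBOR")  # PNG magic in base64
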
@ -5,6 +5,7 @@ VALID_VISION_MODELS = [
     "qwen2_5_vl",
 ]

+
 # Copied from trl
 def disable_dropout_in_model(model: torch.nn.Module) -> None:
     for module in model.modules():

@ -34,7 +34,7 @@ class VL_RLVRWorkflow(RLVRWorkflow):
             return_tensors="pt",
         )

-        input_ids=processed_input["input_ids"].tolist()[0]
+        input_ids = processed_input["input_ids"].tolist()[0]

         n_samples = self.gconfig.n_samples

@ -62,13 +62,19 @@ class VL_RLVRWorkflow(RLVRWorkflow):
                 completion_ids=resp.output_tokens,
                 **data,
             )

             res = dict(
                 # unsqueeze to add an additional batch dimension
                 input_ids=torch.tensor(seq).unsqueeze(0),
                 loss_mask=torch.tensor(loss_mask).unsqueeze(0),
-                pixel_values=processed_input["pixel_values"].clone().detach().unsqueeze(0),
-                image_grid_thw=processed_input["image_grid_thw"].clone().detach().unsqueeze(0),
+                pixel_values=processed_input["pixel_values"]
+                .clone()
+                .detach()
+                .unsqueeze(0),
+                image_grid_thw=processed_input["image_grid_thw"]
+                .clone()
+                .detach()
+                .unsqueeze(0),
                 logprobs=torch.tensor(logprobs).unsqueeze(0),
                 versions=torch.tensor(versions).unsqueeze(0),
                 attention_mask=torch.ones(len(seq), dtype=torch.bool).unsqueeze(0),

@ -1,13 +1,15 @@
 import os
 import re
 import sys
-import wandb
+
 import torch
 import torch.distributed as dist
+import wandb
 from torchdata.stateful_dataloader import StatefulDataLoader

 from arealite.api.cli_args import GRPOConfig, load_expr_config
 from arealite.api.io_struct import FinetuneSpec, WeightUpdateMeta
+from arealite.dataset.__init__ import get_custom_dataset
 from arealite.engine.ppo.actor import FSDPPPOActor
 from arealite.engine.sglang_remote import RemoteSGLangEngine
 from arealite.utils.device import log_gpu_stats

@ -17,13 +19,12 @@ from arealite.utils.stats_logger import StatsLogger
 from arealite.workflow.vl_rlvr import VL_RLVRWorkflow
 from realhf.api.core.data_api import load_hf_processor_and_tokenizer
 from realhf.base import stats_tracker
-from arealite.dataset.__init__ import get_custom_dataset


 def extract_answer(pred_str, data_name, use_last_number=True):
     match = re.findall(r"\[([0-9\.]+)\]", pred_str)
     if match:
         return match[-1]

     return ""

@ -55,32 +56,35 @@ def extract_solution(solution_str, method="strict") -> str | None:
             break
     return final_answer

-def clevr_count_70k_reward_fn(prompt, completions, prompt_ids, completion_ids, answer, **kwargs):
+
+def clevr_count_70k_reward_fn(
+    prompt, completions, prompt_ids, completion_ids, answer, **kwargs
+):
     is_thinking = "thinking" in completions.lower()

     sol = extract_answer(completions, data_name="")  # str number
-    ans =answer
+    ans = answer

     if sol is None:
         return 0
     if ans is None:
         return 0

     if sol.strip() == ans.strip():
         print(f"completions: {completions}, answer: {answer}")
         if is_thinking:
             return 1
         else:
             return 1

     if re.match(r"^\[\d+(\.\d+)?\]$", sol.strip()):
         return 0.05

     return 0


 def main(args):
-    os.environ["WANDB_API_KEY"]=""
+    os.environ["WANDB_API_KEY"] = ""
     wandb.init(project="clevr_70k")

     config, _ = load_expr_config(args, GRPOConfig)

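The reward is 1 only on an exact string match between the extracted bracketed count and the ground truth (the is_thinking branch currently returns 1 either way), with a small 0.05 branch for bracket-formatted near-misses and 0 otherwise. Expected behavior on toy inputs, as a sketch with illustrative argument values:

r = clevr_count_70k_reward_fn(
    prompt="How many cubes are there?",
    completions="I count them one by one... [5]",
    prompt_ids=[],
    completion_ids=[],
    answer="5",
)
# r == 1: extract_answer pulls "5" from the last "[...]" and it matches the answer.

r = clevr_count_70k_reward_fn(
    prompt="How many cubes are there?",
    completions="[4]",
    prompt_ids=[],
    completion_ids=[],
    answer="5",
)
# r == 0: the extracted "4" does not match "5".
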
@ -89,22 +93,22 @@ def main(args):
     rank = int(os.getenv("RANK"))
     world_size = int(os.getenv("WORLD_SIZE"))
     processor, tokenizer = load_hf_processor_and_tokenizer(config.tokenizer_path)
-    train_dataset=get_custom_dataset(
+    train_dataset = get_custom_dataset(
         path=config.train_dataset.path,
         rank=rank,
         world_size=world_size,
         split="train",
         training_type="rl",
-        processor=processor
+        processor=processor,
     )
-    valid_dataset=get_custom_dataset(
+    valid_dataset = get_custom_dataset(
         path=config.valid_dataset.path,
         rank=rank,
         world_size=world_size,
         split="test",
         training_type="rl",
-        processor=processor
+        processor=processor,
     )
     # Create dataset and dataloaders
     train_dataloader = StatefulDataLoader(
         train_dataset,

@ -141,7 +145,7 @@ def main(args):
     actor.initialize(None, ft_spec)
     ref = None
     if config.actor.kl_ctl > 0 and config.ref is not None:
-        ref =FSDPPPOActor(config=config.ref)
+        ref = FSDPPPOActor(config=config.ref)
         ref.initialize(None, ft_spec)

     # Create rollout workflow

@ -199,7 +203,7 @@ def main(args):
         if ref is not None:
             with stats_tracker.record_timing("ref_logp"):
                 batch["ref_logp"] = ref.compute_logp(batch)

             log_gpu_stats("ref logp")

         with stats_tracker.record_timing("compute_advantage"):

@ -211,8 +215,8 @@ def main(args):
             stats_tracker.scope("grpo_actor"),
         ):
             stats = actor.ppo_update(batch)
-            wandb.log({"actor_reward": stats[0]['grpo_actor/final_reward/avg']})
+            wandb.log({"actor_reward": stats[0]["grpo_actor/final_reward/avg"]})

             actor.step_lr_scheduler()
             log_gpu_stats("ppo update")

@ -251,7 +255,14 @@ def main(args):
                 cnt += 1
             batch = eval_rollout.wait(cnt, timeout=None)
             rewards = batch["rewards"].float().to(actor.device)
-            wandb.log({"eval_reward": rewards.mean().item(), "epoch": epoch, "step": step, "global_step": global_step})
+            wandb.log(
+                {
+                    "eval_reward": rewards.mean().item(),
+                    "epoch": epoch,
+                    "step": step,
+                    "global_step": global_step,
+                }
+            )
             with stats_tracker.scope("grpo-eval"):
                 stats_tracker.denominator(
                     n_seqs=torch.ones(

@ -280,6 +291,6 @@ def main(args):
     actor.destroy()
     wandb.finish()


 if __name__ == "__main__":
     main(sys.argv[1:])

@ -1,9 +1,11 @@
 import os
 import sys

 from torchdata.stateful_dataloader import StatefulDataLoader

 from arealite.api.cli_args import SFTConfig, load_expr_config
 from arealite.api.io_struct import FinetuneSpec
+from arealite.dataset.__init__ import get_custom_dataset
 from arealite.engine.sft.lm_engine import FSDPLMEngine
 from arealite.utils.data import pad_sequences_to_tensors
 from arealite.utils.evaluator import Evaluator

@ -11,7 +13,6 @@ from arealite.utils.saver import Saver
 from arealite.utils.stats_logger import StatsLogger
 from realhf.api.core.data_api import load_hf_processor_and_tokenizer
 from realhf.base import stats_tracker
-from arealite.dataset.__init__ import get_custom_dataset


 def main_sft():

@ -21,25 +22,25 @@ def main_sft():
     rank = int(os.getenv("RANK"))
     world_size = int(os.getenv("WORLD_SIZE"))
     processor, tokenizer = load_hf_processor_and_tokenizer(config.tokenizer_path)
-    train_dataset=get_custom_dataset(
+    train_dataset = get_custom_dataset(
         path=config.train_dataset.path,
         rank=rank,
         world_size=world_size,
         split="train",
         training_type="sft",
         tokenizer=tokenizer,
         processor=processor,
     )
-    valid_dataset=get_custom_dataset(
+    valid_dataset = get_custom_dataset(
         path=config.valid_dataset.path,
         rank=rank,
         world_size=world_size,
         split="test",
         training_type="sft",
         tokenizer=tokenizer,
         processor=processor,
     )

     # Create dataset and dataloaders
     train_dataloader = StatefulDataLoader(
         train_dataset,

@ -1,26 +1,39 @@
-from typing import Any, Dict, Optional, Union
+import base64
 import math
-from PIL.Image import Image as ImageObject
 import os
+from io import BytesIO
+from typing import Any, Dict, Optional, Union

 from datasets import load_dataset
 from datasets.distributed import split_dataset_by_node
-import base64
-from io import BytesIO
+from PIL.Image import Image as ImageObject


-def input_text(text:str):
+def input_text(text: str):
     return {"type": "input_text", "text": text}


 def input_image(base64_image: str):
-    return {"type": "input_image", "image_url": f"data:image/jpeg;base64,{base64_image}"}
+    return {
+        "type": "input_image",
+        "image_url": f"data:image/jpeg;base64,{base64_image}",
+    }


-def build_raw_message(sample: Dict[str, Any], base64_images: list[str]) -> list[Dict[str, Any]]:
+def build_raw_message(
+    sample: Dict[str, Any], base64_images: list[str]
+) -> list[Dict[str, Any]]:
     raw_message = []
-    problem_parts = [part.strip() for part in sample["problem"].split("<image>") if part.strip()]
+    problem_parts = [
+        part.strip() for part in sample["problem"].split("<image>") if part.strip()
+    ]
     insert_list = []
     for i, part in enumerate(problem_parts):
         if i > 0 or sample["problem"].startswith("<image>"):
             insert_list.append("image")
         part = part.strip()
         if part:
             insert_list.append("text")
     image_index = 0
     text_index = 0

@ -38,17 +51,25 @@ def build_raw_message(sample: Dict[str, Any], base64_images: list[str]) -> list[

 def encode_image(image_file):
     return base64.b64encode(image_file).decode("utf-8")


 def convert_image(
-    image: Union[Dict[str, Any], ImageObject, str], min_pixels: Optional[int], max_pixels: Optional[int]
+    image: Union[Dict[str, Any], ImageObject, str],
+    min_pixels: Optional[int],
+    max_pixels: Optional[int],
 ) -> ImageObject:
     if max_pixels is not None and (image.width * image.height) > max_pixels:
         resize_factor = math.sqrt(max_pixels / (image.width * image.height))
-        width, height = int(image.width * resize_factor), int(image.height * resize_factor)
+        width, height = int(image.width * resize_factor), int(
+            image.height * resize_factor
+        )
         image = image.resize((width, height))

     if min_pixels is not None and (image.width * image.height) < min_pixels:
         resize_factor = math.sqrt(min_pixels / (image.width * image.height))
-        width, height = int(image.width * resize_factor), int(image.height * resize_factor)
+        width, height = int(image.width * resize_factor), int(
+            image.height * resize_factor
+        )
         image = image.resize((width, height))

     if image.mode != "RGB":

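convert_image clamps an image's area into [min_pixels, max_pixels] by scaling both sides with the square root of the area ratio, so the aspect ratio is preserved. A worked example with the upper bound used in this file (336 * 336):

import math

w, h = 1024, 768                     # 786432 pixels, above max_pixels
max_pixels = 336 * 336               # 112896
f = math.sqrt(max_pixels / (w * h))  # ~0.3789
print(int(w * f), int(h * f))        # -> 387 290, area 112230, just under the cap
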
@ -57,32 +78,34 @@ def convert_image(
         image.save(output, format="JPEG")
         return output.getvalue()


 def get_clevr_count_70k_sft_dataset(path, split, processor, rank, world_size):
-    '''
+    """
     "clevr_count_70k": {
         "image_key": "images",
         "question_key": "problem",
         "answer_key": "answer"
     },
-    '''
+    """
     dataset = load_dataset(path=path, split=split)
     dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)

     tokenizer = processor.tokenizer

     def process_example(example, idx):
         # Add query_id column
         images = example["images"]
-        if 'qwen' in processor.image_processor.image_processor_type.lower():
-            image_token="<|vision_start|><|image_pad|><|vision_end|>"
+        if "qwen" in processor.image_processor.image_processor_type.lower():
+            image_token = "<|vision_start|><|image_pad|><|vision_end|>"
         else:
             image_token = processor.image_token if processor is not None else "<image>"
         example["problem"] = example["problem"].replace("<image>", image_token)
         processed_images = []
         for image in images:
-            processed_images.append(convert_image(image,113*113,336*336))
+            processed_images.append(convert_image(image, 113 * 113, 336 * 336))
         example["images"] = processed_images
         example["seq"] = example["problem"] + example["answer"] + tokenizer.eos_token

         return example

     dataset = dataset.map(

@ -91,8 +114,8 @@ def get_clevr_count_70k_sft_dataset(path, split, processor, rank, world_size):
     )

     def _process(example):
-        text=example["seq"]
-        processed_input=processor(
+        text = example["seq"]
+        processed_input = processor(
             text=[text],
             images=example["images"],
             padding=False,

@ -101,38 +124,52 @@ def get_clevr_count_70k_sft_dataset(path, split, processor, rank, world_size):
             return_attention_mask=False,
         )

-        example["input_ids"] =processed_input["input_ids"].squeeze(0)
+        example["input_ids"] = processed_input["input_ids"].squeeze(0)
         example["pixel_values"] = processed_input["pixel_values"]
         example["image_grid_thw"] = processed_input["image_grid_thw"]
         answer_token = tokenizer.encode(example["answer"])
-        loss_mask = [0] * (len(example["input_ids"]) - len(answer_token))+[1]*len(answer_token)
-        example["loss_mask"]=loss_mask
+        loss_mask = [0] * (len(example["input_ids"]) - len(answer_token)) + [1] * len(
+            answer_token
+        )
+        example["loss_mask"] = loss_mask
         return example

-    dataset = dataset.map(lambda x: _process(x),remove_columns=["images","seq","problem","answer"])
+    dataset = dataset.map(
+        lambda x: _process(x), remove_columns=["images", "seq", "problem", "answer"]
+    )
     return dataset

-def get_clevr_count_70k_rl_dataset(path, split,processor, rank, world_size):
+
+def get_clevr_count_70k_rl_dataset(path, split, processor, rank, world_size):
     dataset = load_dataset(path=path, split=split)
     dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)

     def process(sample):
-        processed_images = [convert_image(image, 113*113, 336*336) for image in sample["images"]]
-        if 'qwen' in processor.image_processor.image_processor_type.lower():
-            image_token="<|vision_start|><|image_pad|><|vision_end|>"
+        processed_images = [
+            convert_image(image, 113 * 113, 336 * 336) for image in sample["images"]
+        ]
+        if "qwen" in processor.image_processor.image_processor_type.lower():
+            image_token = "<|vision_start|><|image_pad|><|vision_end|>"
         else:
             image_token = processor.image_token if processor is not None else "<image>"
         system_prompt = {
             "role": "system",
             "content": (
                 "Solve the following question: count the number of items in the image and provide the final answer in [ ] format, ensuring that only the number is inside the brackets without any additional text or explanations. "
-            )
+            ),
         }

-        messages =[{"role": "user", "content": sample["problem"].replace("<image>", image_token)}]
+        messages = [
+            {
+                "role": "user",
+                "content": sample["problem"].replace("<image>", image_token),
+            }
+        ]
         messages.insert(0, system_prompt)
-        messages=processor.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+        messages = processor.tokenizer.apply_chat_template(
+            messages, add_generation_prompt=True, tokenize=False
+        )
        return {"messages": messages, "images": processed_images}

     dataset = dataset.map(process).remove_columns(["problem"])
     return dataset

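The loss_mask built in _process zeroes out the prompt and supervises only the answer tokens: for a tokenized sequence of length L whose last A tokens come from the answer, the mask is (L - A) zeros followed by A ones. In miniature:

input_ids = [101, 102, 103, 104, 105]  # 5 tokens total (prompt + answer)
answer_token = [104, 105]              # last 2 tokens belong to the answer
loss_mask = [0] * (len(input_ids) - len(answer_token)) + [1] * len(answer_token)
assert loss_mask == [0, 0, 0, 1, 1]
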
@ -1,9 +1,11 @@
-from datasets import load_dataset, Dataset
+from datasets import Dataset, load_dataset
 from datasets.distributed import split_dataset_by_node


 def get_gsm8k_sft_dataset(path, split, tokenizer, rank, world_size):
     dataset = load_dataset(path=path, name="main", split=split)
     dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)

     def process(sample):
         seq_token = tokenizer.encode(
             sample["question"] + sample["answer"] + tokenizer.eos_token

@ -15,12 +17,14 @@ def get_gsm8k_sft_dataset(path, split, tokenizer, rank, world_size):
     dataset = dataset.map(process).remove_columns(["question", "answer"])
     return dataset

-def get_gsm8k_rl_dataset(path,split, rank, world_size):
+
+def get_gsm8k_rl_dataset(path, split, rank, world_size):
     dataset = load_dataset(path=path, name="main", split=split)
     dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)

     def process(sample):
         messages = [{"role": "user", "content": sample["question"]}]
         return {"messages": messages}

     dataset = dataset.map(process).remove_columns(["question"])
     return dataset

@ -68,6 +68,7 @@ def load_hf_tokenizer(
         tokenizer.pad_token_id = tokenizer.eos_token_id
     return tokenizer

+
 @lru_cache(maxsize=8)
 def load_hf_processor_and_tokenizer(
     model_name_or_path: str,

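Since load_hf_processor_and_tokenizer is wrapped in @lru_cache(maxsize=8), repeated calls with the same path return the same (processor, tokenizer) objects instead of reloading from disk. A quick check, with an illustrative model path:

p1, t1 = load_hf_processor_and_tokenizer("Qwen/Qwen2.5-VL-3B-Instruct")
p2, t2 = load_hf_processor_and_tokenizer("Qwen/Qwen2.5-VL-3B-Instruct")
assert p1 is p2 and t1 is t2  # cached: identical objects on the second call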