mirror of https://github.com/inclusionAI/AReaL
0721_merge6
This commit is contained in:
parent
aed6a9013c
commit
c29561498e
|
@ -0,0 +1,186 @@
|
|||
# Legacy codes
|
||||
.legacy/
|
||||
.data/
|
||||
.idea/
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
trace_result/
|
||||
profile_result/
|
||||
|
||||
slurm_outs
|
||||
_data
|
||||
*.nfs*
|
||||
output
|
||||
logs
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/#use-with-ide
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
# openai api key
|
||||
api_key.txt
|
||||
api_key.json
|
||||
|
||||
./*.sh
|
||||
*.jpg
|
||||
*.pdf
|
||||
*.swp
|
||||
|
||||
.vscode/
|
||||
wandb/
|
||||
outputs/
|
||||
sympy/
|
Binary file not shown.
|
@ -9,11 +9,10 @@ from typing import Any, Dict, List, Literal, Optional, Tuple
|
|||
|
||||
import torch
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from arealite.api.cli_args import GenerationHyperparameters
|
||||
from PIL.Image import Image as ImageObject
|
||||
from transformers import AutoProcessor, PreTrainedTokenizerFast
|
||||
|
||||
|
||||
from transformers import PreTrainedTokenizerFast,AutoProcessor
|
||||
from arealite.api.cli_args import GenerationHyperparameters
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
@ -1,146 +0,0 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
import abc
|
||||
import asyncio
|
||||
import random
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
import transformers
|
||||
|
||||
from arealite.api.cli_args import LLMClientConfig, TrainingArgs
|
||||
from arealite.api.io_struct import (
|
||||
LLMRequest,
|
||||
LLMResponse,
|
||||
LLMServerInfo,
|
||||
WeightMeta,
|
||||
WeightUpdateGroupMeta,
|
||||
)
|
||||
from arealite.api.llm_server_api import LLMServiceRegistry
|
||||
from realhf.api.core.data_api import load_hf_tokenizer
|
||||
|
||||
|
||||
class LLMClient(abc.ABC):
|
||||
def __init__(self, args: TrainingArgs, client_config: LLMClientConfig):
|
||||
self.args = args
|
||||
self.client_config = client_config
|
||||
|
||||
self.registry = LLMServiceRegistry(args.experiment_name, args.trial_name)
|
||||
self.tokenizer: transformers.PreTrainedTokenizerFast = load_hf_tokenizer(
|
||||
args.rollout.model_path
|
||||
)
|
||||
|
||||
def select_server(self):
|
||||
"""Get an available healthy server."""
|
||||
servers = self.get_healthy_servers()
|
||||
min_load = min([server.load for server in servers])
|
||||
servers = [server for server in servers if server.load == min_load]
|
||||
return random.choice(servers)
|
||||
|
||||
def get_healthy_servers(self):
|
||||
servers = self.registry.get_healthy_servers()
|
||||
if not servers:
|
||||
raise RuntimeError("No healthy SGLang servers available")
|
||||
return servers
|
||||
|
||||
def wait_until_servers_ready(self):
|
||||
while len(self.registry.get_healthy_servers()) == 0:
|
||||
time.sleep(10)
|
||||
|
||||
async def arequest_with_retry(
|
||||
self,
|
||||
endpoint: str,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
method: str = "POST",
|
||||
max_retries: Optional[int] = None,
|
||||
timeout: Optional[float] = None,
|
||||
retry_delay: float = 1.0,
|
||||
target_server: Optional[LLMServerInfo] = None,
|
||||
) -> tuple[aiohttp.ClientResponse, LLMServerInfo]:
|
||||
timeout = timeout or self.client_config.request_timeout
|
||||
last_exception = None
|
||||
max_retries = max_retries or self.client_config.request_retries
|
||||
|
||||
# Try with retries
|
||||
for _ in range(max_retries):
|
||||
if target_server is None:
|
||||
server_info = self.select_server()
|
||||
else:
|
||||
server_info = target_server
|
||||
base_url = f"http://{server_info.host}:{server_info.port}"
|
||||
url = f"{base_url}{endpoint}"
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(
|
||||
total=timeout,
|
||||
sock_connect=30,
|
||||
sock_read=timeout,
|
||||
)
|
||||
) as session:
|
||||
if method.upper() == "GET":
|
||||
response = await session.get(url)
|
||||
elif method.upper() == "POST":
|
||||
response = await session.post(url, json=payload)
|
||||
elif method.upper() == "PUT":
|
||||
response = await session.put(url, json=payload)
|
||||
elif method.upper() == "DELETE":
|
||||
response = await session.delete(url)
|
||||
else:
|
||||
raise ValueError(f"Unsupported HTTP method: {method}")
|
||||
|
||||
response.raise_for_status()
|
||||
return response, server_info
|
||||
|
||||
except (
|
||||
aiohttp.ClientError,
|
||||
aiohttp.ClientResponseError,
|
||||
asyncio.TimeoutError,
|
||||
) as e:
|
||||
last_exception = e
|
||||
if attempt < max_retries - 1:
|
||||
await asyncio.sleep(retry_delay)
|
||||
continue
|
||||
raise RuntimeError(
|
||||
f"Failed after {max_retries} retries each. " f"Last error: {last_exception}"
|
||||
)
|
||||
|
||||
async def agenerate(self, req: LLMRequest) -> LLMResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def aupdate_weights_from_disk(self, server_info: LLMServerInfo, path: str):
|
||||
raise NotImplementedError()
|
||||
|
||||
async def ainit_weight_update_group(
|
||||
self, server_info: LLMServerInfo, group_meta: WeightUpdateGroupMeta
|
||||
):
|
||||
raise NotImplementedError()
|
||||
|
||||
async def aupdate_weights_from_distributed(
|
||||
self, server_info: LLMServerInfo, weight_meta: WeightMeta
|
||||
):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMClientFactory:
|
||||
"""Factory class to create LLMClient instances."""
|
||||
|
||||
args: TrainingArgs
|
||||
|
||||
def make_client(self, config: LLMClientConfig) -> LLMClient:
|
||||
"""Create an instance of LLMClient based on the specified type."""
|
||||
if self.args.rollout.server_backend == "sglang":
|
||||
from arealite.system.sglang_client import SGLangClient
|
||||
return SGLangClient(self.args, config)
|
||||
elif self.args.rollout.server_backend == "vl_sglang":
|
||||
from arealite.system.vl_sglang_client import VL_SGLangClient
|
||||
return VL_SGLangClient(self.args, config)
|
||||
|
||||
|
||||
raise ValueError(f"Unknown LLMClient type: {self.args.rollout.server_backend}")
|
|
@ -1,265 +0,0 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
import json
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
from dataclasses import asdict, dataclass
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from arealite.api.cli_args import LLMServiceConfig, TrainingArgs
|
||||
from arealite.api.io_struct import LLMServerInfo
|
||||
from realhf.base import logging, name_resolve, names
|
||||
|
||||
logger = logging.getLogger("LLM Server")
|
||||
|
||||
|
||||
class LLMServiceRegistry:
|
||||
"""A registry class for dynamic server discovery."""
|
||||
|
||||
def __init__(self, expr_name: str, trial_name: str):
|
||||
self.expr_name = expr_name
|
||||
self.trial_name = trial_name
|
||||
self.heartbeat_timeout = 30
|
||||
|
||||
def get_server_key(self, server_id: str) -> str:
|
||||
return names.gen_server(self.expr_name, self.trial_name, server_id)
|
||||
|
||||
def register_server(self, server_info: LLMServerInfo):
|
||||
server_info.last_heartbeat = datetime.now().timestamp()
|
||||
key = self.get_server_key(server_info.server_id)
|
||||
name_resolve.add(
|
||||
key,
|
||||
json.dumps(asdict(server_info)),
|
||||
keepalive_ttl=self.heartbeat_timeout,
|
||||
replace=False,
|
||||
)
|
||||
|
||||
def unregister_server(self, server_id: str):
|
||||
try:
|
||||
name_resolve.delete(self.get_server_key(server_id))
|
||||
except name_resolve.NameEntryNotFoundError:
|
||||
pass
|
||||
|
||||
def update_heartbeat(
|
||||
self, server_id: str, status: str, load: float = 0.0, version: int = 0
|
||||
):
|
||||
try:
|
||||
key = self.get_server_key(server_id)
|
||||
server_data = name_resolve.get(key)
|
||||
server_info = LLMServerInfo(**json.loads(server_data))
|
||||
server_info.last_heartbeat = datetime.now().timestamp()
|
||||
server_info.load = load
|
||||
server_info.status = status
|
||||
server_info.version = version
|
||||
name_resolve.add(
|
||||
key,
|
||||
json.dumps(asdict(server_info)),
|
||||
keepalive_ttl=self.heartbeat_timeout,
|
||||
replace=True,
|
||||
)
|
||||
except (name_resolve.NameEntryNotFoundError, json.JSONDecodeError):
|
||||
pass
|
||||
|
||||
def get_healthy_servers(self) -> List[LLMServerInfo]:
|
||||
servers = []
|
||||
current_time = time.time()
|
||||
try:
|
||||
root = names.gen_server_root(self.expr_name, self.trial_name)
|
||||
server_infos = name_resolve.get_subtree(root)
|
||||
for server_data in server_infos:
|
||||
try:
|
||||
server_info = LLMServerInfo(**json.loads(server_data))
|
||||
if (
|
||||
current_time - server_info.last_heartbeat
|
||||
< self.heartbeat_timeout
|
||||
and server_info.status == "healthy"
|
||||
):
|
||||
servers.append(server_info)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
continue
|
||||
except name_resolve.NameEntryNotFoundError:
|
||||
pass
|
||||
return servers
|
||||
|
||||
|
||||
class LLMServer:
|
||||
def __init__(self, args: TrainingArgs, service_config: LLMServiceConfig):
|
||||
self.args = args
|
||||
self.server_id = str(uuid.uuid4())
|
||||
self.registry = LLMServiceRegistry(args.experiment_name, args.trial_name)
|
||||
self.running = False
|
||||
self.load = 0.0
|
||||
self.process: Optional[subprocess.Popen] = None
|
||||
self.service_config = service_config
|
||||
|
||||
def launch_server(self) -> Optional[LLMServerInfo]:
|
||||
"""Launch the LLM server subprocess. Returns server info or None if failed."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def check_health(self) -> bool:
|
||||
"""Check if the server is healthy."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def start(self):
|
||||
"""Main entry point - start server and run until exit"""
|
||||
try:
|
||||
self._startup()
|
||||
self._run()
|
||||
except Exception as e:
|
||||
logger.error(f"Server error: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
self._graceful_exit(1)
|
||||
|
||||
def _startup(self):
|
||||
"""Initialize and start the server"""
|
||||
self.running = True
|
||||
|
||||
# Launch server process
|
||||
server_info = self.launch_server()
|
||||
if not server_info or not self.process:
|
||||
raise RuntimeError("Failed to launch server")
|
||||
|
||||
logger.info(f"Server {self.server_id} starting")
|
||||
|
||||
# Wait for server to be ready
|
||||
if not self._wait_for_ready():
|
||||
raise RuntimeError(
|
||||
f"Server failed to become ready in {self.service_config.startup_timeout}s"
|
||||
)
|
||||
|
||||
# Register with service registry
|
||||
self.registry.register_server(server_info)
|
||||
|
||||
# Start health monitoring
|
||||
health_thread = threading.Thread(target=self._health_monitor, daemon=True)
|
||||
health_thread.start()
|
||||
|
||||
logger.info(
|
||||
f"Server {self.server_id} ready and registered at http://{server_info.host}:{server_info.port}"
|
||||
)
|
||||
|
||||
def _wait_for_ready(self) -> bool:
|
||||
"""Wait for server to become healthy"""
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < self.service_config.startup_timeout:
|
||||
if not self.running or (self.process and self.process.poll() is not None):
|
||||
return False
|
||||
if self.check_health():
|
||||
return True
|
||||
time.sleep(2)
|
||||
return False
|
||||
|
||||
def _run(self):
|
||||
"""Main server loop"""
|
||||
try:
|
||||
while self.running:
|
||||
# Check if subprocess died
|
||||
if self.process and self.process.poll() is not None:
|
||||
logger.error(
|
||||
f"Server process died (code: {self.process.returncode})"
|
||||
)
|
||||
self._graceful_exit(1)
|
||||
time.sleep(1)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Keyboard interrupt received")
|
||||
self._graceful_exit(0)
|
||||
|
||||
def _health_monitor(self):
|
||||
"""Monitor server health and exit if unhealthy"""
|
||||
failures = 0
|
||||
max_failures = self.service_config.max_unhealth_count
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
# Check process first
|
||||
if self.process and self.process.poll() is not None:
|
||||
logger.error("Server process died")
|
||||
self._graceful_exit(1)
|
||||
break
|
||||
|
||||
# Check health
|
||||
if self.check_health():
|
||||
failures = 0
|
||||
self.registry.update_heartbeat(self.server_id, "healthy", self.load)
|
||||
else:
|
||||
failures += 1
|
||||
logger.warning(f"Health check failed ({failures}/{max_failures})")
|
||||
|
||||
if failures >= max_failures:
|
||||
logger.error("Too many health check failures")
|
||||
self.registry.update_heartbeat(
|
||||
self.server_id, "unhealthy", self.load
|
||||
)
|
||||
if self.service_config.graceful_shutdown_on_unhealthy:
|
||||
self._graceful_exit(1)
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Health monitor error: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
failures += 1
|
||||
if (
|
||||
failures >= max_failures
|
||||
and self.service_config.graceful_shutdown_on_unhealthy
|
||||
):
|
||||
self._graceful_exit(1)
|
||||
break
|
||||
|
||||
time.sleep(self.service_config.health_check_interval)
|
||||
|
||||
def _graceful_exit(self, exit_code: int):
|
||||
"""Clean shutdown and exit"""
|
||||
if not self.running:
|
||||
return
|
||||
|
||||
logger.info(f"Graceful shutdown initiated (exit code: {exit_code})")
|
||||
self.running = False
|
||||
|
||||
# Cleanup registry
|
||||
try:
|
||||
self.registry.unregister_server(self.server_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"Registry cleanup failed: {e}")
|
||||
logger.warning(traceback.format_exc())
|
||||
|
||||
# Stop process
|
||||
if self.process and self.process.poll() is None:
|
||||
try:
|
||||
self.process.terminate()
|
||||
self.process.wait(timeout=5)
|
||||
logger.info("Server terminated gracefully")
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.warning("Force killing server")
|
||||
try:
|
||||
self.process.kill()
|
||||
self.process.wait()
|
||||
except (ProcessLookupError, OSError):
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"Process cleanup failed: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
if exit_code != 0:
|
||||
sys.exit(exit_code)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMServerFactory:
|
||||
args: TrainingArgs
|
||||
|
||||
def make_server(self, server_config: LLMServiceConfig) -> LLMServer:
|
||||
"""Create an LLM server instance based on the configuration."""
|
||||
if self.args.rollout.server_backend == "sglang":
|
||||
from arealite.system.sglang_server import SGLangServer
|
||||
|
||||
return SGLangServer(self.args, server_config)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported server backend: {self.args.rollout.server_backend}"
|
||||
)
|
|
@ -12,6 +12,7 @@ from torchdata.stateful_dataloader import StatefulDataLoader
|
|||
|
||||
from arealite.api.cli_args import TrainerConfig, TrainingArgs
|
||||
from realhf.base import constants
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from arealite.system.rollout_controller import RolloutController
|
||||
|
||||
|
|
|
@ -1,21 +0,0 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
|
||||
|
||||
|
||||
from arealite.api.cli_args import LLMClientConfig, TrainingArgs
|
||||
from arealite.api.llm_server_api import LLMServiceRegistry
|
||||
from realhf.api.core.data_api import load_hf_processor_and_tokenizer
|
||||
from arealite.api.llm_client_api import LLMClient
|
||||
|
||||
class VLMClient(LLMClient):
|
||||
"""A client for interacting with VLM servers."""
|
||||
|
||||
def __init__(self, args: TrainingArgs, client_config: LLMClientConfig):
|
||||
super().__init__(args, client_config)
|
||||
self.registry = LLMServiceRegistry(args.experiment_name, args.trial_name)
|
||||
self.processor, self.tokenizer = load_hf_processor_and_tokenizer(
|
||||
args.rollout.model_path
|
||||
)
|
||||
|
|
@ -1,5 +1,7 @@
|
|||
from typing import Optional
|
||||
|
||||
import transformers
|
||||
|
||||
VALID_DATASETS = ["gsm8k", "clevr_count_70k"]
|
||||
|
||||
def get_custom_dataset(
|
||||
|
@ -20,10 +22,14 @@ def get_custom_dataset(
|
|||
from examples.arealite.dataset.gsm8k import get_gsm8k_rl_dataset
|
||||
return get_gsm8k_rl_dataset(path, split, rank, world_size)
|
||||
elif "clevr_count_70k" in path and training_type == "sft":
|
||||
from examples.arealite.dataset.clevr_count_70k import get_clevr_count_70k_sft_dataset
|
||||
from examples.arealite.dataset.clevr_count_70k import (
|
||||
get_clevr_count_70k_sft_dataset,
|
||||
)
|
||||
return get_clevr_count_70k_sft_dataset(path, split, processor, rank, world_size)
|
||||
elif "clevr_count_70k" in path and training_type == "rl":
|
||||
from examples.arealite.dataset.clevr_count_70k import get_clevr_count_70k_rl_dataset
|
||||
from examples.arealite.dataset.clevr_count_70k import (
|
||||
get_clevr_count_70k_rl_dataset,
|
||||
)
|
||||
return get_clevr_count_70k_rl_dataset(path, split,processor, rank, world_size)
|
||||
else:
|
||||
raise ValueError(
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
|
||||
from arealite.engine.ppo.actor import PPOActor, PPOActorConfig
|
||||
from arealite.engine.vl_fsdp_engine import VL_FSDPEngine
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import os
|
||||
import time
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -10,30 +10,19 @@ from torch.distributed.checkpoint.state_dict import (
|
|||
StateDictOptions,
|
||||
get_model_state_dict,
|
||||
)
|
||||
from transformers import (
|
||||
AutoModelForImageTextToText,
|
||||
get_constant_schedule_with_warmup,
|
||||
get_linear_schedule_with_warmup,
|
||||
)
|
||||
from transformers import AutoModelForImageTextToText
|
||||
|
||||
from arealite.api.cli_args import TrainEngineConfig
|
||||
from arealite.api.engine_api import (
|
||||
FinetuneSpec,
|
||||
SaveLoadMeta,
|
||||
WeightUpdateMeta,
|
||||
)
|
||||
from arealite.api.engine_api import FinetuneSpec, SaveLoadMeta, WeightUpdateMeta
|
||||
from arealite.engine.fsdp_engine import FSDPEngine
|
||||
from arealite.utils.data import (
|
||||
MicroBatchList,
|
||||
amend_position_ids,
|
||||
pack_tensor_dict,
|
||||
pad_and_stack_tensors_along_first_dim,
|
||||
pad_mb_list,
|
||||
reorder_list,
|
||||
split_padded_tensor_dict_into_mb_list,
|
||||
unpack_sequence,
|
||||
unsqueeze_mb_list,
|
||||
)
|
||||
from arealite.utils.model import disable_dropout_in_model
|
||||
from arealite.utils.fsdp import (
|
||||
CPUOffloadPolicy,
|
||||
MixedPrecisionPolicy,
|
||||
|
@ -41,10 +30,9 @@ from arealite.utils.fsdp import (
|
|||
create_fsdp_device_mesh,
|
||||
get_cosine_schedule_with_warmup,
|
||||
)
|
||||
from realhf.base import logging, name_resolve, names, pkg_version
|
||||
from arealite.utils.model import disable_dropout_in_model
|
||||
from realhf.api.core.data_api import load_hf_processor_and_tokenizer
|
||||
from arealite.engine.fsdp_engine import FSDPEngine
|
||||
|
||||
from realhf.base import logging, name_resolve, names, pkg_version
|
||||
|
||||
logger = logging.getLogger("FSDPEngine")
|
||||
|
||||
|
|
|
@ -1,11 +1,8 @@
|
|||
import time
|
||||
|
||||
from arealite.api.cli_args import InferenceEngineConfig
|
||||
from arealite.api.io_struct import VLMRequest, VLMResponse
|
||||
from arealite.engine.sglang_remote import RemoteSGLangEngine
|
||||
from arealite.api.io_struct import (
|
||||
VLMRequest,
|
||||
VLMResponse
|
||||
)
|
||||
from arealite.utils.http import arequest_with_retry
|
||||
from realhf.base import logging, pkg_version
|
||||
|
||||
|
|
File diff suppressed because one or more lines are too long
|
@ -1,19 +0,0 @@
|
|||
# Use a pipeline as a high-level helper
|
||||
from transformers import AutoProcessor
|
||||
processor=AutoProcessor.from_pretrained(pretrained_model_name_or_path="/storage/openpsi/models/Qwen2.5-VL-32B-Instruct")
|
||||
input_ids = [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 151652, 151655, 151653, 4340, 1657, 3589, 525, 1052, 304, 279, 2168, 30, 151645, 198, 151644, 77091, 198]
|
||||
|
||||
decoded_text = processor.tokenizer.decode(input_ids)
|
||||
print(decoded_text)
|
||||
|
||||
# pipe = pipeline("image-text-to-text", model="/storage/openpsi/models/Qwen2.5-VL-32B-Instruct",device_map="auto" )
|
||||
# messages = [
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": [
|
||||
# {"type": "image", "url": "output_image.jpg"},
|
||||
# {"type": "text", "text": "What is shown in the image??"}
|
||||
# ]
|
||||
# },
|
||||
# ]
|
||||
# print(pipe(text=messages))
|
|
@ -1,6 +1,6 @@
|
|||
import asyncio
|
||||
from typing import Any, Dict, Optional
|
||||
from PIL import Image
|
||||
|
||||
import aiohttp
|
||||
|
||||
from realhf.base import logging
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
from io import BytesIO
|
||||
import base64
|
||||
import math
|
||||
from torch import Tensor
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from PIL.Image import Image as ImageObject
|
||||
from dataclasses import MISSING
|
||||
from io import BytesIO
|
||||
from typing import List
|
||||
|
||||
from PIL.Image import Image as ImageObject
|
||||
|
||||
|
||||
def image2base64(images: List[ImageObject]|ImageObject)-> List[str]|str:
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import getpass
|
||||
import os
|
||||
|
||||
from transformers import PreTrainedTokenizerFast, AutoProcessor
|
||||
from transformers import AutoProcessor, PreTrainedTokenizerFast
|
||||
|
||||
from arealite.api.cli_args import SaverConfig
|
||||
from arealite.api.engine_api import TrainEngine
|
||||
|
|
|
@ -3,12 +3,14 @@ import uuid
|
|||
|
||||
import torch
|
||||
from tensordict import TensorDict
|
||||
from transformers import PreTrainedTokenizerFast,AutoProcessor
|
||||
from transformers import AutoProcessor, PreTrainedTokenizerFast
|
||||
|
||||
from arealite.api.cli_args import GenerationHyperparameters
|
||||
from arealite.api.io_struct import VLMRequest
|
||||
from arealite.workflow.rlvr import RLVRWorkflow
|
||||
from arealite.utils.padding import concat_padded_tensors
|
||||
from arealite.utils.data import concat_padded_tensors
|
||||
from arealite.utils.image import image2base64
|
||||
from arealite.workflow.rlvr import RLVRWorkflow
|
||||
|
||||
|
||||
class VL_RLVRWorkflow(RLVRWorkflow):
|
||||
def __init__(
|
||||
|
|
Loading…
Reference in New Issue