mirror of https://github.com/inclusionAI/AReaL
0721_merge6
This commit is contained in:
parent
aed6a9013c
commit
c29561498e
|
@ -0,0 +1,186 @@
|
||||||
|
# Legacy codes
|
||||||
|
.legacy/
|
||||||
|
.data/
|
||||||
|
.idea/
|
||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
trace_result/
|
||||||
|
profile_result/
|
||||||
|
|
||||||
|
slurm_outs
|
||||||
|
_data
|
||||||
|
*.nfs*
|
||||||
|
output
|
||||||
|
logs
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
cover/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
.pybuilder/
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
# For a library or package, you might want to ignore these files since the code is
|
||||||
|
# intended to run in multiple environments; otherwise, check them in:
|
||||||
|
# .python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
#Pipfile.lock
|
||||||
|
|
||||||
|
# poetry
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||||
|
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||||
|
# commonly ignored for libraries.
|
||||||
|
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||||
|
#poetry.lock
|
||||||
|
|
||||||
|
# pdm
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||||
|
#pdm.lock
|
||||||
|
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||||
|
# in version control.
|
||||||
|
# https://pdm.fming.dev/#use-with-ide
|
||||||
|
.pdm.toml
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# pytype static type analyzer
|
||||||
|
.pytype/
|
||||||
|
|
||||||
|
# Cython debug symbols
|
||||||
|
cython_debug/
|
||||||
|
|
||||||
|
# PyCharm
|
||||||
|
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||||
|
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||||
|
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||||
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
|
#.idea/
|
||||||
|
|
||||||
|
# openai api key
|
||||||
|
api_key.txt
|
||||||
|
api_key.json
|
||||||
|
|
||||||
|
./*.sh
|
||||||
|
*.jpg
|
||||||
|
*.pdf
|
||||||
|
*.swp
|
||||||
|
|
||||||
|
.vscode/
|
||||||
|
wandb/
|
||||||
|
outputs/
|
||||||
|
sympy/
|
Binary file not shown.
|
@ -9,11 +9,10 @@ from typing import Any, Dict, List, Literal, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from gymnasium.core import ActType, ObsType
|
from gymnasium.core import ActType, ObsType
|
||||||
from arealite.api.cli_args import GenerationHyperparameters
|
|
||||||
from PIL.Image import Image as ImageObject
|
from PIL.Image import Image as ImageObject
|
||||||
|
from transformers import AutoProcessor, PreTrainedTokenizerFast
|
||||||
|
|
||||||
|
from arealite.api.cli_args import GenerationHyperparameters
|
||||||
from transformers import PreTrainedTokenizerFast,AutoProcessor
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
|
@ -1,146 +0,0 @@
|
||||||
# Copyright 2025 Ant Group Inc.
|
|
||||||
# Licensed under the Apache License, Version 2.0
|
|
||||||
|
|
||||||
import abc
|
|
||||||
import asyncio
|
|
||||||
import random
|
|
||||||
import time
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Any, Dict, Optional
|
|
||||||
|
|
||||||
import aiohttp
|
|
||||||
import requests
|
|
||||||
import transformers
|
|
||||||
|
|
||||||
from arealite.api.cli_args import LLMClientConfig, TrainingArgs
|
|
||||||
from arealite.api.io_struct import (
|
|
||||||
LLMRequest,
|
|
||||||
LLMResponse,
|
|
||||||
LLMServerInfo,
|
|
||||||
WeightMeta,
|
|
||||||
WeightUpdateGroupMeta,
|
|
||||||
)
|
|
||||||
from arealite.api.llm_server_api import LLMServiceRegistry
|
|
||||||
from realhf.api.core.data_api import load_hf_tokenizer
|
|
||||||
|
|
||||||
|
|
||||||
class LLMClient(abc.ABC):
|
|
||||||
def __init__(self, args: TrainingArgs, client_config: LLMClientConfig):
|
|
||||||
self.args = args
|
|
||||||
self.client_config = client_config
|
|
||||||
|
|
||||||
self.registry = LLMServiceRegistry(args.experiment_name, args.trial_name)
|
|
||||||
self.tokenizer: transformers.PreTrainedTokenizerFast = load_hf_tokenizer(
|
|
||||||
args.rollout.model_path
|
|
||||||
)
|
|
||||||
|
|
||||||
def select_server(self):
|
|
||||||
"""Get an available healthy server."""
|
|
||||||
servers = self.get_healthy_servers()
|
|
||||||
min_load = min([server.load for server in servers])
|
|
||||||
servers = [server for server in servers if server.load == min_load]
|
|
||||||
return random.choice(servers)
|
|
||||||
|
|
||||||
def get_healthy_servers(self):
|
|
||||||
servers = self.registry.get_healthy_servers()
|
|
||||||
if not servers:
|
|
||||||
raise RuntimeError("No healthy SGLang servers available")
|
|
||||||
return servers
|
|
||||||
|
|
||||||
def wait_until_servers_ready(self):
|
|
||||||
while len(self.registry.get_healthy_servers()) == 0:
|
|
||||||
time.sleep(10)
|
|
||||||
|
|
||||||
async def arequest_with_retry(
|
|
||||||
self,
|
|
||||||
endpoint: str,
|
|
||||||
payload: Optional[Dict[str, Any]] = None,
|
|
||||||
method: str = "POST",
|
|
||||||
max_retries: Optional[int] = None,
|
|
||||||
timeout: Optional[float] = None,
|
|
||||||
retry_delay: float = 1.0,
|
|
||||||
target_server: Optional[LLMServerInfo] = None,
|
|
||||||
) -> tuple[aiohttp.ClientResponse, LLMServerInfo]:
|
|
||||||
timeout = timeout or self.client_config.request_timeout
|
|
||||||
last_exception = None
|
|
||||||
max_retries = max_retries or self.client_config.request_retries
|
|
||||||
|
|
||||||
# Try with retries
|
|
||||||
for _ in range(max_retries):
|
|
||||||
if target_server is None:
|
|
||||||
server_info = self.select_server()
|
|
||||||
else:
|
|
||||||
server_info = target_server
|
|
||||||
base_url = f"http://{server_info.host}:{server_info.port}"
|
|
||||||
url = f"{base_url}{endpoint}"
|
|
||||||
|
|
||||||
for attempt in range(max_retries):
|
|
||||||
try:
|
|
||||||
async with aiohttp.ClientSession(
|
|
||||||
timeout=aiohttp.ClientTimeout(
|
|
||||||
total=timeout,
|
|
||||||
sock_connect=30,
|
|
||||||
sock_read=timeout,
|
|
||||||
)
|
|
||||||
) as session:
|
|
||||||
if method.upper() == "GET":
|
|
||||||
response = await session.get(url)
|
|
||||||
elif method.upper() == "POST":
|
|
||||||
response = await session.post(url, json=payload)
|
|
||||||
elif method.upper() == "PUT":
|
|
||||||
response = await session.put(url, json=payload)
|
|
||||||
elif method.upper() == "DELETE":
|
|
||||||
response = await session.delete(url)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported HTTP method: {method}")
|
|
||||||
|
|
||||||
response.raise_for_status()
|
|
||||||
return response, server_info
|
|
||||||
|
|
||||||
except (
|
|
||||||
aiohttp.ClientError,
|
|
||||||
aiohttp.ClientResponseError,
|
|
||||||
asyncio.TimeoutError,
|
|
||||||
) as e:
|
|
||||||
last_exception = e
|
|
||||||
if attempt < max_retries - 1:
|
|
||||||
await asyncio.sleep(retry_delay)
|
|
||||||
continue
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Failed after {max_retries} retries each. " f"Last error: {last_exception}"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def agenerate(self, req: LLMRequest) -> LLMResponse:
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
async def aupdate_weights_from_disk(self, server_info: LLMServerInfo, path: str):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
async def ainit_weight_update_group(
|
|
||||||
self, server_info: LLMServerInfo, group_meta: WeightUpdateGroupMeta
|
|
||||||
):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
async def aupdate_weights_from_distributed(
|
|
||||||
self, server_info: LLMServerInfo, weight_meta: WeightMeta
|
|
||||||
):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class LLMClientFactory:
|
|
||||||
"""Factory class to create LLMClient instances."""
|
|
||||||
|
|
||||||
args: TrainingArgs
|
|
||||||
|
|
||||||
def make_client(self, config: LLMClientConfig) -> LLMClient:
|
|
||||||
"""Create an instance of LLMClient based on the specified type."""
|
|
||||||
if self.args.rollout.server_backend == "sglang":
|
|
||||||
from arealite.system.sglang_client import SGLangClient
|
|
||||||
return SGLangClient(self.args, config)
|
|
||||||
elif self.args.rollout.server_backend == "vl_sglang":
|
|
||||||
from arealite.system.vl_sglang_client import VL_SGLangClient
|
|
||||||
return VL_SGLangClient(self.args, config)
|
|
||||||
|
|
||||||
|
|
||||||
raise ValueError(f"Unknown LLMClient type: {self.args.rollout.server_backend}")
|
|
|
@ -1,265 +0,0 @@
|
||||||
# Copyright 2025 Ant Group Inc.
|
|
||||||
# Licensed under the Apache License, Version 2.0
|
|
||||||
|
|
||||||
import json
|
|
||||||
import subprocess
|
|
||||||
import sys
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
import traceback
|
|
||||||
import uuid
|
|
||||||
from dataclasses import asdict, dataclass
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from arealite.api.cli_args import LLMServiceConfig, TrainingArgs
|
|
||||||
from arealite.api.io_struct import LLMServerInfo
|
|
||||||
from realhf.base import logging, name_resolve, names
|
|
||||||
|
|
||||||
logger = logging.getLogger("LLM Server")
|
|
||||||
|
|
||||||
|
|
||||||
class LLMServiceRegistry:
|
|
||||||
"""A registry class for dynamic server discovery."""
|
|
||||||
|
|
||||||
def __init__(self, expr_name: str, trial_name: str):
|
|
||||||
self.expr_name = expr_name
|
|
||||||
self.trial_name = trial_name
|
|
||||||
self.heartbeat_timeout = 30
|
|
||||||
|
|
||||||
def get_server_key(self, server_id: str) -> str:
|
|
||||||
return names.gen_server(self.expr_name, self.trial_name, server_id)
|
|
||||||
|
|
||||||
def register_server(self, server_info: LLMServerInfo):
|
|
||||||
server_info.last_heartbeat = datetime.now().timestamp()
|
|
||||||
key = self.get_server_key(server_info.server_id)
|
|
||||||
name_resolve.add(
|
|
||||||
key,
|
|
||||||
json.dumps(asdict(server_info)),
|
|
||||||
keepalive_ttl=self.heartbeat_timeout,
|
|
||||||
replace=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
def unregister_server(self, server_id: str):
|
|
||||||
try:
|
|
||||||
name_resolve.delete(self.get_server_key(server_id))
|
|
||||||
except name_resolve.NameEntryNotFoundError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def update_heartbeat(
|
|
||||||
self, server_id: str, status: str, load: float = 0.0, version: int = 0
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
key = self.get_server_key(server_id)
|
|
||||||
server_data = name_resolve.get(key)
|
|
||||||
server_info = LLMServerInfo(**json.loads(server_data))
|
|
||||||
server_info.last_heartbeat = datetime.now().timestamp()
|
|
||||||
server_info.load = load
|
|
||||||
server_info.status = status
|
|
||||||
server_info.version = version
|
|
||||||
name_resolve.add(
|
|
||||||
key,
|
|
||||||
json.dumps(asdict(server_info)),
|
|
||||||
keepalive_ttl=self.heartbeat_timeout,
|
|
||||||
replace=True,
|
|
||||||
)
|
|
||||||
except (name_resolve.NameEntryNotFoundError, json.JSONDecodeError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def get_healthy_servers(self) -> List[LLMServerInfo]:
|
|
||||||
servers = []
|
|
||||||
current_time = time.time()
|
|
||||||
try:
|
|
||||||
root = names.gen_server_root(self.expr_name, self.trial_name)
|
|
||||||
server_infos = name_resolve.get_subtree(root)
|
|
||||||
for server_data in server_infos:
|
|
||||||
try:
|
|
||||||
server_info = LLMServerInfo(**json.loads(server_data))
|
|
||||||
if (
|
|
||||||
current_time - server_info.last_heartbeat
|
|
||||||
< self.heartbeat_timeout
|
|
||||||
and server_info.status == "healthy"
|
|
||||||
):
|
|
||||||
servers.append(server_info)
|
|
||||||
except (json.JSONDecodeError, TypeError):
|
|
||||||
continue
|
|
||||||
except name_resolve.NameEntryNotFoundError:
|
|
||||||
pass
|
|
||||||
return servers
|
|
||||||
|
|
||||||
|
|
||||||
class LLMServer:
|
|
||||||
def __init__(self, args: TrainingArgs, service_config: LLMServiceConfig):
|
|
||||||
self.args = args
|
|
||||||
self.server_id = str(uuid.uuid4())
|
|
||||||
self.registry = LLMServiceRegistry(args.experiment_name, args.trial_name)
|
|
||||||
self.running = False
|
|
||||||
self.load = 0.0
|
|
||||||
self.process: Optional[subprocess.Popen] = None
|
|
||||||
self.service_config = service_config
|
|
||||||
|
|
||||||
def launch_server(self) -> Optional[LLMServerInfo]:
|
|
||||||
"""Launch the LLM server subprocess. Returns server info or None if failed."""
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def check_health(self) -> bool:
|
|
||||||
"""Check if the server is healthy."""
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def start(self):
|
|
||||||
"""Main entry point - start server and run until exit"""
|
|
||||||
try:
|
|
||||||
self._startup()
|
|
||||||
self._run()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Server error: {e}")
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
self._graceful_exit(1)
|
|
||||||
|
|
||||||
def _startup(self):
|
|
||||||
"""Initialize and start the server"""
|
|
||||||
self.running = True
|
|
||||||
|
|
||||||
# Launch server process
|
|
||||||
server_info = self.launch_server()
|
|
||||||
if not server_info or not self.process:
|
|
||||||
raise RuntimeError("Failed to launch server")
|
|
||||||
|
|
||||||
logger.info(f"Server {self.server_id} starting")
|
|
||||||
|
|
||||||
# Wait for server to be ready
|
|
||||||
if not self._wait_for_ready():
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Server failed to become ready in {self.service_config.startup_timeout}s"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Register with service registry
|
|
||||||
self.registry.register_server(server_info)
|
|
||||||
|
|
||||||
# Start health monitoring
|
|
||||||
health_thread = threading.Thread(target=self._health_monitor, daemon=True)
|
|
||||||
health_thread.start()
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Server {self.server_id} ready and registered at http://{server_info.host}:{server_info.port}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _wait_for_ready(self) -> bool:
|
|
||||||
"""Wait for server to become healthy"""
|
|
||||||
start_time = time.time()
|
|
||||||
while time.time() - start_time < self.service_config.startup_timeout:
|
|
||||||
if not self.running or (self.process and self.process.poll() is not None):
|
|
||||||
return False
|
|
||||||
if self.check_health():
|
|
||||||
return True
|
|
||||||
time.sleep(2)
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _run(self):
|
|
||||||
"""Main server loop"""
|
|
||||||
try:
|
|
||||||
while self.running:
|
|
||||||
# Check if subprocess died
|
|
||||||
if self.process and self.process.poll() is not None:
|
|
||||||
logger.error(
|
|
||||||
f"Server process died (code: {self.process.returncode})"
|
|
||||||
)
|
|
||||||
self._graceful_exit(1)
|
|
||||||
time.sleep(1)
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
logger.info("Keyboard interrupt received")
|
|
||||||
self._graceful_exit(0)
|
|
||||||
|
|
||||||
def _health_monitor(self):
|
|
||||||
"""Monitor server health and exit if unhealthy"""
|
|
||||||
failures = 0
|
|
||||||
max_failures = self.service_config.max_unhealth_count
|
|
||||||
|
|
||||||
while self.running:
|
|
||||||
try:
|
|
||||||
# Check process first
|
|
||||||
if self.process and self.process.poll() is not None:
|
|
||||||
logger.error("Server process died")
|
|
||||||
self._graceful_exit(1)
|
|
||||||
break
|
|
||||||
|
|
||||||
# Check health
|
|
||||||
if self.check_health():
|
|
||||||
failures = 0
|
|
||||||
self.registry.update_heartbeat(self.server_id, "healthy", self.load)
|
|
||||||
else:
|
|
||||||
failures += 1
|
|
||||||
logger.warning(f"Health check failed ({failures}/{max_failures})")
|
|
||||||
|
|
||||||
if failures >= max_failures:
|
|
||||||
logger.error("Too many health check failures")
|
|
||||||
self.registry.update_heartbeat(
|
|
||||||
self.server_id, "unhealthy", self.load
|
|
||||||
)
|
|
||||||
if self.service_config.graceful_shutdown_on_unhealthy:
|
|
||||||
self._graceful_exit(1)
|
|
||||||
break
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Health monitor error: {e}")
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
failures += 1
|
|
||||||
if (
|
|
||||||
failures >= max_failures
|
|
||||||
and self.service_config.graceful_shutdown_on_unhealthy
|
|
||||||
):
|
|
||||||
self._graceful_exit(1)
|
|
||||||
break
|
|
||||||
|
|
||||||
time.sleep(self.service_config.health_check_interval)
|
|
||||||
|
|
||||||
def _graceful_exit(self, exit_code: int):
|
|
||||||
"""Clean shutdown and exit"""
|
|
||||||
if not self.running:
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.info(f"Graceful shutdown initiated (exit code: {exit_code})")
|
|
||||||
self.running = False
|
|
||||||
|
|
||||||
# Cleanup registry
|
|
||||||
try:
|
|
||||||
self.registry.unregister_server(self.server_id)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Registry cleanup failed: {e}")
|
|
||||||
logger.warning(traceback.format_exc())
|
|
||||||
|
|
||||||
# Stop process
|
|
||||||
if self.process and self.process.poll() is None:
|
|
||||||
try:
|
|
||||||
self.process.terminate()
|
|
||||||
self.process.wait(timeout=5)
|
|
||||||
logger.info("Server terminated gracefully")
|
|
||||||
except subprocess.TimeoutExpired:
|
|
||||||
logger.warning("Force killing server")
|
|
||||||
try:
|
|
||||||
self.process.kill()
|
|
||||||
self.process.wait()
|
|
||||||
except (ProcessLookupError, OSError):
|
|
||||||
pass
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Process cleanup failed: {e}")
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
|
|
||||||
if exit_code != 0:
|
|
||||||
sys.exit(exit_code)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class LLMServerFactory:
|
|
||||||
args: TrainingArgs
|
|
||||||
|
|
||||||
def make_server(self, server_config: LLMServiceConfig) -> LLMServer:
|
|
||||||
"""Create an LLM server instance based on the configuration."""
|
|
||||||
if self.args.rollout.server_backend == "sglang":
|
|
||||||
from arealite.system.sglang_server import SGLangServer
|
|
||||||
|
|
||||||
return SGLangServer(self.args, server_config)
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"Unsupported server backend: {self.args.rollout.server_backend}"
|
|
||||||
)
|
|
|
@ -12,6 +12,7 @@ from torchdata.stateful_dataloader import StatefulDataLoader
|
||||||
|
|
||||||
from arealite.api.cli_args import TrainerConfig, TrainingArgs
|
from arealite.api.cli_args import TrainerConfig, TrainingArgs
|
||||||
from realhf.base import constants
|
from realhf.base import constants
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from arealite.system.rollout_controller import RolloutController
|
from arealite.system.rollout_controller import RolloutController
|
||||||
|
|
||||||
|
|
|
@ -1,21 +0,0 @@
|
||||||
# Copyright 2025 Ant Group Inc.
|
|
||||||
# Licensed under the Apache License, Version 2.0
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
from arealite.api.cli_args import LLMClientConfig, TrainingArgs
|
|
||||||
from arealite.api.llm_server_api import LLMServiceRegistry
|
|
||||||
from realhf.api.core.data_api import load_hf_processor_and_tokenizer
|
|
||||||
from arealite.api.llm_client_api import LLMClient
|
|
||||||
|
|
||||||
class VLMClient(LLMClient):
|
|
||||||
"""A client for interacting with VLM servers."""
|
|
||||||
|
|
||||||
def __init__(self, args: TrainingArgs, client_config: LLMClientConfig):
|
|
||||||
super().__init__(args, client_config)
|
|
||||||
self.registry = LLMServiceRegistry(args.experiment_name, args.trial_name)
|
|
||||||
self.processor, self.tokenizer = load_hf_processor_and_tokenizer(
|
|
||||||
args.rollout.model_path
|
|
||||||
)
|
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
|
|
||||||
VALID_DATASETS = ["gsm8k", "clevr_count_70k"]
|
VALID_DATASETS = ["gsm8k", "clevr_count_70k"]
|
||||||
|
|
||||||
def get_custom_dataset(
|
def get_custom_dataset(
|
||||||
|
@ -20,10 +22,14 @@ def get_custom_dataset(
|
||||||
from examples.arealite.dataset.gsm8k import get_gsm8k_rl_dataset
|
from examples.arealite.dataset.gsm8k import get_gsm8k_rl_dataset
|
||||||
return get_gsm8k_rl_dataset(path, split, rank, world_size)
|
return get_gsm8k_rl_dataset(path, split, rank, world_size)
|
||||||
elif "clevr_count_70k" in path and training_type == "sft":
|
elif "clevr_count_70k" in path and training_type == "sft":
|
||||||
from examples.arealite.dataset.clevr_count_70k import get_clevr_count_70k_sft_dataset
|
from examples.arealite.dataset.clevr_count_70k import (
|
||||||
|
get_clevr_count_70k_sft_dataset,
|
||||||
|
)
|
||||||
return get_clevr_count_70k_sft_dataset(path, split, processor, rank, world_size)
|
return get_clevr_count_70k_sft_dataset(path, split, processor, rank, world_size)
|
||||||
elif "clevr_count_70k" in path and training_type == "rl":
|
elif "clevr_count_70k" in path and training_type == "rl":
|
||||||
from examples.arealite.dataset.clevr_count_70k import get_clevr_count_70k_rl_dataset
|
from examples.arealite.dataset.clevr_count_70k import (
|
||||||
|
get_clevr_count_70k_rl_dataset,
|
||||||
|
)
|
||||||
return get_clevr_count_70k_rl_dataset(path, split,processor, rank, world_size)
|
return get_clevr_count_70k_rl_dataset(path, split,processor, rank, world_size)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from arealite.engine.ppo.actor import PPOActor, PPOActorConfig
|
from arealite.engine.ppo.actor import PPOActor, PPOActorConfig
|
||||||
from arealite.engine.vl_fsdp_engine import VL_FSDPEngine
|
from arealite.engine.vl_fsdp_engine import VL_FSDPEngine
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from typing import Any, Callable, Dict, List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
@ -10,30 +10,19 @@ from torch.distributed.checkpoint.state_dict import (
|
||||||
StateDictOptions,
|
StateDictOptions,
|
||||||
get_model_state_dict,
|
get_model_state_dict,
|
||||||
)
|
)
|
||||||
from transformers import (
|
from transformers import AutoModelForImageTextToText
|
||||||
AutoModelForImageTextToText,
|
|
||||||
get_constant_schedule_with_warmup,
|
|
||||||
get_linear_schedule_with_warmup,
|
|
||||||
)
|
|
||||||
|
|
||||||
from arealite.api.cli_args import TrainEngineConfig
|
from arealite.api.cli_args import TrainEngineConfig
|
||||||
from arealite.api.engine_api import (
|
from arealite.api.engine_api import FinetuneSpec, SaveLoadMeta, WeightUpdateMeta
|
||||||
FinetuneSpec,
|
from arealite.engine.fsdp_engine import FSDPEngine
|
||||||
SaveLoadMeta,
|
|
||||||
WeightUpdateMeta,
|
|
||||||
)
|
|
||||||
from arealite.utils.data import (
|
from arealite.utils.data import (
|
||||||
MicroBatchList,
|
MicroBatchList,
|
||||||
amend_position_ids,
|
amend_position_ids,
|
||||||
pack_tensor_dict,
|
pack_tensor_dict,
|
||||||
pad_and_stack_tensors_along_first_dim,
|
|
||||||
pad_mb_list,
|
pad_mb_list,
|
||||||
reorder_list,
|
|
||||||
split_padded_tensor_dict_into_mb_list,
|
split_padded_tensor_dict_into_mb_list,
|
||||||
unpack_sequence,
|
|
||||||
unsqueeze_mb_list,
|
unsqueeze_mb_list,
|
||||||
)
|
)
|
||||||
from arealite.utils.model import disable_dropout_in_model
|
|
||||||
from arealite.utils.fsdp import (
|
from arealite.utils.fsdp import (
|
||||||
CPUOffloadPolicy,
|
CPUOffloadPolicy,
|
||||||
MixedPrecisionPolicy,
|
MixedPrecisionPolicy,
|
||||||
|
@ -41,10 +30,9 @@ from arealite.utils.fsdp import (
|
||||||
create_fsdp_device_mesh,
|
create_fsdp_device_mesh,
|
||||||
get_cosine_schedule_with_warmup,
|
get_cosine_schedule_with_warmup,
|
||||||
)
|
)
|
||||||
from realhf.base import logging, name_resolve, names, pkg_version
|
from arealite.utils.model import disable_dropout_in_model
|
||||||
from realhf.api.core.data_api import load_hf_processor_and_tokenizer
|
from realhf.api.core.data_api import load_hf_processor_and_tokenizer
|
||||||
from arealite.engine.fsdp_engine import FSDPEngine
|
from realhf.base import logging, name_resolve, names, pkg_version
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger("FSDPEngine")
|
logger = logging.getLogger("FSDPEngine")
|
||||||
|
|
||||||
|
|
|
@ -1,11 +1,8 @@
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from arealite.api.cli_args import InferenceEngineConfig
|
from arealite.api.cli_args import InferenceEngineConfig
|
||||||
|
from arealite.api.io_struct import VLMRequest, VLMResponse
|
||||||
from arealite.engine.sglang_remote import RemoteSGLangEngine
|
from arealite.engine.sglang_remote import RemoteSGLangEngine
|
||||||
from arealite.api.io_struct import (
|
|
||||||
VLMRequest,
|
|
||||||
VLMResponse
|
|
||||||
)
|
|
||||||
from arealite.utils.http import arequest_with_retry
|
from arealite.utils.http import arequest_with_retry
|
||||||
from realhf.base import logging, pkg_version
|
from realhf.base import logging, pkg_version
|
||||||
|
|
||||||
|
|
File diff suppressed because one or more lines are too long
|
@ -1,19 +0,0 @@
|
||||||
# Use a pipeline as a high-level helper
|
|
||||||
from transformers import AutoProcessor
|
|
||||||
processor=AutoProcessor.from_pretrained(pretrained_model_name_or_path="/storage/openpsi/models/Qwen2.5-VL-32B-Instruct")
|
|
||||||
input_ids = [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 151652, 151655, 151653, 4340, 1657, 3589, 525, 1052, 304, 279, 2168, 30, 151645, 198, 151644, 77091, 198]
|
|
||||||
|
|
||||||
decoded_text = processor.tokenizer.decode(input_ids)
|
|
||||||
print(decoded_text)
|
|
||||||
|
|
||||||
# pipe = pipeline("image-text-to-text", model="/storage/openpsi/models/Qwen2.5-VL-32B-Instruct",device_map="auto" )
|
|
||||||
# messages = [
|
|
||||||
# {
|
|
||||||
# "role": "user",
|
|
||||||
# "content": [
|
|
||||||
# {"type": "image", "url": "output_image.jpg"},
|
|
||||||
# {"type": "text", "text": "What is shown in the image??"}
|
|
||||||
# ]
|
|
||||||
# },
|
|
||||||
# ]
|
|
||||||
# print(pipe(text=messages))
|
|
|
@ -1,6 +1,6 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
from PIL import Image
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
from realhf.base import logging
|
from realhf.base import logging
|
||||||
|
|
|
@ -1,10 +1,11 @@
|
||||||
from io import BytesIO
|
|
||||||
import base64
|
import base64
|
||||||
import math
|
import math
|
||||||
from torch import Tensor
|
|
||||||
from typing import Any, Dict, List, Optional, Union
|
|
||||||
from PIL.Image import Image as ImageObject
|
|
||||||
from dataclasses import MISSING
|
from dataclasses import MISSING
|
||||||
|
from io import BytesIO
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from PIL.Image import Image as ImageObject
|
||||||
|
|
||||||
|
|
||||||
def image2base64(images: List[ImageObject]|ImageObject)-> List[str]|str:
|
def image2base64(images: List[ImageObject]|ImageObject)-> List[str]|str:
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import getpass
|
import getpass
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from transformers import PreTrainedTokenizerFast, AutoProcessor
|
from transformers import AutoProcessor, PreTrainedTokenizerFast
|
||||||
|
|
||||||
from arealite.api.cli_args import SaverConfig
|
from arealite.api.cli_args import SaverConfig
|
||||||
from arealite.api.engine_api import TrainEngine
|
from arealite.api.engine_api import TrainEngine
|
||||||
|
|
|
@ -3,12 +3,14 @@ import uuid
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from tensordict import TensorDict
|
from tensordict import TensorDict
|
||||||
from transformers import PreTrainedTokenizerFast,AutoProcessor
|
from transformers import AutoProcessor, PreTrainedTokenizerFast
|
||||||
|
|
||||||
from arealite.api.cli_args import GenerationHyperparameters
|
from arealite.api.cli_args import GenerationHyperparameters
|
||||||
from arealite.api.io_struct import VLMRequest
|
from arealite.api.io_struct import VLMRequest
|
||||||
from arealite.workflow.rlvr import RLVRWorkflow
|
from arealite.utils.data import concat_padded_tensors
|
||||||
from arealite.utils.padding import concat_padded_tensors
|
|
||||||
from arealite.utils.image import image2base64
|
from arealite.utils.image import image2base64
|
||||||
|
from arealite.workflow.rlvr import RLVRWorkflow
|
||||||
|
|
||||||
|
|
||||||
class VL_RLVRWorkflow(RLVRWorkflow):
|
class VL_RLVRWorkflow(RLVRWorkflow):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
Loading…
Reference in New Issue