0721_merge6

朱晗 2025-07-21 17:57:29 +08:00
parent aed6a9013c
commit c29561498e
17 changed files with 218 additions and 505 deletions

.gitignore vendored Normal file (186 additions)
View File

@@ -0,0 +1,186 @@
# Legacy codes
.legacy/
.data/
.idea/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

trace_result/
profile_result/
slurm_outs
_data
*.nfs*
output
logs

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# openai api key
api_key.txt
api_key.json
/*.sh
*.jpg
*.pdf
*.swp
.vscode/
wandb/
outputs/
sympy/

View File

@@ -9,11 +9,10 @@ from typing import Any, Dict, List, Literal, Optional, Tuple
 import torch
 from gymnasium.core import ActType, ObsType
-from arealite.api.cli_args import GenerationHyperparameters
 from PIL.Image import Image as ImageObject
-from transformers import PreTrainedTokenizerFast,AutoProcessor
+from transformers import AutoProcessor, PreTrainedTokenizerFast
+from arealite.api.cli_args import GenerationHyperparameters

 @dataclass

View File

@@ -1,146 +0,0 @@
# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0
import abc
import asyncio
import random
import time
from dataclasses import dataclass
from typing import Any, Dict, Optional

import aiohttp
import requests
import transformers

from arealite.api.cli_args import LLMClientConfig, TrainingArgs
from arealite.api.io_struct import (
    LLMRequest,
    LLMResponse,
    LLMServerInfo,
    WeightMeta,
    WeightUpdateGroupMeta,
)
from arealite.api.llm_server_api import LLMServiceRegistry
from realhf.api.core.data_api import load_hf_tokenizer


class LLMClient(abc.ABC):
    def __init__(self, args: TrainingArgs, client_config: LLMClientConfig):
        self.args = args
        self.client_config = client_config
        self.registry = LLMServiceRegistry(args.experiment_name, args.trial_name)
        self.tokenizer: transformers.PreTrainedTokenizerFast = load_hf_tokenizer(
            args.rollout.model_path
        )

    def select_server(self):
        """Get an available healthy server."""
        servers = self.get_healthy_servers()
        min_load = min([server.load for server in servers])
        servers = [server for server in servers if server.load == min_load]
        return random.choice(servers)

    def get_healthy_servers(self):
        servers = self.registry.get_healthy_servers()
        if not servers:
            raise RuntimeError("No healthy SGLang servers available")
        return servers

    def wait_until_servers_ready(self):
        while len(self.registry.get_healthy_servers()) == 0:
            time.sleep(10)

    async def arequest_with_retry(
        self,
        endpoint: str,
        payload: Optional[Dict[str, Any]] = None,
        method: str = "POST",
        max_retries: Optional[int] = None,
        timeout: Optional[float] = None,
        retry_delay: float = 1.0,
        target_server: Optional[LLMServerInfo] = None,
    ) -> tuple[aiohttp.ClientResponse, LLMServerInfo]:
        timeout = timeout or self.client_config.request_timeout
        last_exception = None
        max_retries = max_retries or self.client_config.request_retries
        # Try with retries
        for _ in range(max_retries):
            if target_server is None:
                server_info = self.select_server()
            else:
                server_info = target_server
            base_url = f"http://{server_info.host}:{server_info.port}"
            url = f"{base_url}{endpoint}"
            for attempt in range(max_retries):
                try:
                    async with aiohttp.ClientSession(
                        timeout=aiohttp.ClientTimeout(
                            total=timeout,
                            sock_connect=30,
                            sock_read=timeout,
                        )
                    ) as session:
                        if method.upper() == "GET":
                            response = await session.get(url)
                        elif method.upper() == "POST":
                            response = await session.post(url, json=payload)
                        elif method.upper() == "PUT":
                            response = await session.put(url, json=payload)
                        elif method.upper() == "DELETE":
                            response = await session.delete(url)
                        else:
                            raise ValueError(f"Unsupported HTTP method: {method}")
                        response.raise_for_status()
                        return response, server_info
                except (
                    aiohttp.ClientError,
                    aiohttp.ClientResponseError,
                    asyncio.TimeoutError,
                ) as e:
                    last_exception = e
                    if attempt < max_retries - 1:
                        await asyncio.sleep(retry_delay)
                    continue
        raise RuntimeError(
            f"Failed after {max_retries} retries each. " f"Last error: {last_exception}"
        )

    async def agenerate(self, req: LLMRequest) -> LLMResponse:
        raise NotImplementedError()

    async def aupdate_weights_from_disk(self, server_info: LLMServerInfo, path: str):
        raise NotImplementedError()

    async def ainit_weight_update_group(
        self, server_info: LLMServerInfo, group_meta: WeightUpdateGroupMeta
    ):
        raise NotImplementedError()

    async def aupdate_weights_from_distributed(
        self, server_info: LLMServerInfo, weight_meta: WeightMeta
    ):
        raise NotImplementedError()


@dataclass
class LLMClientFactory:
    """Factory class to create LLMClient instances."""

    args: TrainingArgs

    def make_client(self, config: LLMClientConfig) -> LLMClient:
        """Create an instance of LLMClient based on the specified type."""
        if self.args.rollout.server_backend == "sglang":
            from arealite.system.sglang_client import SGLangClient

            return SGLangClient(self.args, config)
        elif self.args.rollout.server_backend == "vl_sglang":
            from arealite.system.vl_sglang_client import VL_SGLangClient

            return VL_SGLangClient(self.args, config)
        raise ValueError(f"Unknown LLMClient type: {self.args.rollout.server_backend}")

View File

@@ -1,265 +0,0 @@
# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0
import json
import subprocess
import sys
import threading
import time
import traceback
import uuid
from dataclasses import asdict, dataclass
from datetime import datetime
from typing import List, Optional

from arealite.api.cli_args import LLMServiceConfig, TrainingArgs
from arealite.api.io_struct import LLMServerInfo
from realhf.base import logging, name_resolve, names

logger = logging.getLogger("LLM Server")


class LLMServiceRegistry:
    """A registry class for dynamic server discovery."""

    def __init__(self, expr_name: str, trial_name: str):
        self.expr_name = expr_name
        self.trial_name = trial_name
        self.heartbeat_timeout = 30

    def get_server_key(self, server_id: str) -> str:
        return names.gen_server(self.expr_name, self.trial_name, server_id)

    def register_server(self, server_info: LLMServerInfo):
        server_info.last_heartbeat = datetime.now().timestamp()
        key = self.get_server_key(server_info.server_id)
        name_resolve.add(
            key,
            json.dumps(asdict(server_info)),
            keepalive_ttl=self.heartbeat_timeout,
            replace=False,
        )

    def unregister_server(self, server_id: str):
        try:
            name_resolve.delete(self.get_server_key(server_id))
        except name_resolve.NameEntryNotFoundError:
            pass

    def update_heartbeat(
        self, server_id: str, status: str, load: float = 0.0, version: int = 0
    ):
        try:
            key = self.get_server_key(server_id)
            server_data = name_resolve.get(key)
            server_info = LLMServerInfo(**json.loads(server_data))
            server_info.last_heartbeat = datetime.now().timestamp()
            server_info.load = load
            server_info.status = status
            server_info.version = version
            name_resolve.add(
                key,
                json.dumps(asdict(server_info)),
                keepalive_ttl=self.heartbeat_timeout,
                replace=True,
            )
        except (name_resolve.NameEntryNotFoundError, json.JSONDecodeError):
            pass

    def get_healthy_servers(self) -> List[LLMServerInfo]:
        servers = []
        current_time = time.time()
        try:
            root = names.gen_server_root(self.expr_name, self.trial_name)
            server_infos = name_resolve.get_subtree(root)
            for server_data in server_infos:
                try:
                    server_info = LLMServerInfo(**json.loads(server_data))
                    if (
                        current_time - server_info.last_heartbeat
                        < self.heartbeat_timeout
                        and server_info.status == "healthy"
                    ):
                        servers.append(server_info)
                except (json.JSONDecodeError, TypeError):
                    continue
        except name_resolve.NameEntryNotFoundError:
            pass
        return servers


class LLMServer:
    def __init__(self, args: TrainingArgs, service_config: LLMServiceConfig):
        self.args = args
        self.server_id = str(uuid.uuid4())
        self.registry = LLMServiceRegistry(args.experiment_name, args.trial_name)
        self.running = False
        self.load = 0.0
        self.process: Optional[subprocess.Popen] = None
        self.service_config = service_config

    def launch_server(self) -> Optional[LLMServerInfo]:
        """Launch the LLM server subprocess. Returns server info or None if failed."""
        raise NotImplementedError()

    def check_health(self) -> bool:
        """Check if the server is healthy."""
        raise NotImplementedError()

    def start(self):
        """Main entry point: start server and run until exit."""
        try:
            self._startup()
            self._run()
        except Exception as e:
            logger.error(f"Server error: {e}")
            logger.error(traceback.format_exc())
            self._graceful_exit(1)

    def _startup(self):
        """Initialize and start the server."""
        self.running = True
        # Launch server process
        server_info = self.launch_server()
        if not server_info or not self.process:
            raise RuntimeError("Failed to launch server")
        logger.info(f"Server {self.server_id} starting")
        # Wait for server to be ready
        if not self._wait_for_ready():
            raise RuntimeError(
                f"Server failed to become ready in {self.service_config.startup_timeout}s"
            )
        # Register with service registry
        self.registry.register_server(server_info)
        # Start health monitoring
        health_thread = threading.Thread(target=self._health_monitor, daemon=True)
        health_thread.start()
        logger.info(
            f"Server {self.server_id} ready and registered at http://{server_info.host}:{server_info.port}"
        )

    def _wait_for_ready(self) -> bool:
        """Wait for server to become healthy."""
        start_time = time.time()
        while time.time() - start_time < self.service_config.startup_timeout:
            if not self.running or (self.process and self.process.poll() is not None):
                return False
            if self.check_health():
                return True
            time.sleep(2)
        return False

    def _run(self):
        """Main server loop."""
        try:
            while self.running:
                # Check if subprocess died
                if self.process and self.process.poll() is not None:
                    logger.error(
                        f"Server process died (code: {self.process.returncode})"
                    )
                    self._graceful_exit(1)
                time.sleep(1)
        except KeyboardInterrupt:
            logger.info("Keyboard interrupt received")
            self._graceful_exit(0)

    def _health_monitor(self):
        """Monitor server health and exit if unhealthy."""
        failures = 0
        max_failures = self.service_config.max_unhealth_count
        while self.running:
            try:
                # Check process first
                if self.process and self.process.poll() is not None:
                    logger.error("Server process died")
                    self._graceful_exit(1)
                    break
                # Check health
                if self.check_health():
                    failures = 0
                    self.registry.update_heartbeat(self.server_id, "healthy", self.load)
                else:
                    failures += 1
                    logger.warning(f"Health check failed ({failures}/{max_failures})")
                    if failures >= max_failures:
                        logger.error("Too many health check failures")
                        self.registry.update_heartbeat(
                            self.server_id, "unhealthy", self.load
                        )
                        if self.service_config.graceful_shutdown_on_unhealthy:
                            self._graceful_exit(1)
                            break
            except Exception as e:
                logger.error(f"Health monitor error: {e}")
                logger.error(traceback.format_exc())
                failures += 1
                if (
                    failures >= max_failures
                    and self.service_config.graceful_shutdown_on_unhealthy
                ):
                    self._graceful_exit(1)
                    break
            time.sleep(self.service_config.health_check_interval)

    def _graceful_exit(self, exit_code: int):
        """Clean shutdown and exit."""
        if not self.running:
            return
        logger.info(f"Graceful shutdown initiated (exit code: {exit_code})")
        self.running = False
        # Cleanup registry
        try:
            self.registry.unregister_server(self.server_id)
        except Exception as e:
            logger.warning(f"Registry cleanup failed: {e}")
            logger.warning(traceback.format_exc())
        # Stop process
        if self.process and self.process.poll() is None:
            try:
                self.process.terminate()
                self.process.wait(timeout=5)
                logger.info("Server terminated gracefully")
            except subprocess.TimeoutExpired:
                logger.warning("Force killing server")
                try:
                    self.process.kill()
                    self.process.wait()
                except (ProcessLookupError, OSError):
                    pass
            except Exception as e:
                logger.error(f"Process cleanup failed: {e}")
                logger.error(traceback.format_exc())
        if exit_code != 0:
            sys.exit(exit_code)


@dataclass
class LLMServerFactory:
    args: TrainingArgs

    def make_server(self, server_config: LLMServiceConfig) -> LLMServer:
        """Create an LLM server instance based on the configuration."""
        if self.args.rollout.server_backend == "sglang":
            from arealite.system.sglang_server import SGLangServer

            return SGLangServer(self.args, server_config)
        else:
            raise ValueError(
                f"Unsupported server backend: {self.args.rollout.server_backend}"
            )
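
For reference, the registry above amounts to TTL-based service discovery: servers register a JSON record, refresh it via heartbeats, and consumers filter by status and heartbeat age. A minimal in-memory sketch of the same scheme (illustrative only; the deleted code persists records through name_resolve with a keepalive TTL):

# Standalone, in-memory sketch of the heartbeat/TTL discovery scheme.
# InMemoryRegistry and ServerRecord are illustrative names, not arealite API.
import time
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ServerRecord:
    server_id: str
    status: str = "healthy"
    load: float = 0.0
    last_heartbeat: float = field(default_factory=time.time)


class InMemoryRegistry:
    def __init__(self, heartbeat_timeout: float = 30.0):
        self.heartbeat_timeout = heartbeat_timeout
        self._records: Dict[str, ServerRecord] = {}

    def register(self, rec: ServerRecord) -> None:
        rec.last_heartbeat = time.time()
        self._records[rec.server_id] = rec

    def heartbeat(self, server_id: str, status: str, load: float = 0.0) -> None:
        rec = self._records.get(server_id)
        if rec is None:
            return  # mirrors the tolerant update_heartbeat above
        rec.status, rec.load, rec.last_heartbeat = status, load, time.time()

    def healthy_servers(self) -> List[ServerRecord]:
        # A server that stops heartbeating simply ages out after the timeout.
        now = time.time()
        return [
            r
            for r in self._records.values()
            if r.status == "healthy"
            and now - r.last_heartbeat < self.heartbeat_timeout
        ]

This aging-out behavior is why _health_monitor above refreshes the record on every successful health check: a silent server disappears from healthy_servers() within heartbeat_timeout seconds.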

View File

@@ -12,6 +12,7 @@ from torchdata.stateful_dataloader import StatefulDataLoader
 from arealite.api.cli_args import TrainerConfig, TrainingArgs
 from realhf.base import constants
+
 if TYPE_CHECKING:
     from arealite.system.rollout_controller import RolloutController

View File

@@ -1,21 +0,0 @@
# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0
from arealite.api.cli_args import LLMClientConfig, TrainingArgs
from arealite.api.llm_server_api import LLMServiceRegistry
from realhf.api.core.data_api import load_hf_processor_and_tokenizer

from arealite.api.llm_client_api import LLMClient


class VLMClient(LLMClient):
    """A client for interacting with VLM servers."""

    def __init__(self, args: TrainingArgs, client_config: LLMClientConfig):
        super().__init__(args, client_config)
        self.registry = LLMServiceRegistry(args.experiment_name, args.trial_name)
        self.processor, self.tokenizer = load_hf_processor_and_tokenizer(
            args.rollout.model_path
        )

View File

@@ -1,5 +1,7 @@
 from typing import Optional
+
 import transformers
+
 VALID_DATASETS = ["gsm8k", "clevr_count_70k"]

 def get_custom_dataset(
@@ -20,10 +22,14 @@ def get_custom_dataset(
         from examples.arealite.dataset.gsm8k import get_gsm8k_rl_dataset

         return get_gsm8k_rl_dataset(path, split, rank, world_size)
     elif "clevr_count_70k" in path and training_type == "sft":
-        from examples.arealite.dataset.clevr_count_70k import get_clevr_count_70k_sft_dataset
+        from examples.arealite.dataset.clevr_count_70k import (
+            get_clevr_count_70k_sft_dataset,
+        )
         return get_clevr_count_70k_sft_dataset(path, split, processor, rank, world_size)
     elif "clevr_count_70k" in path and training_type == "rl":
-        from examples.arealite.dataset.clevr_count_70k import get_clevr_count_70k_rl_dataset
+        from examples.arealite.dataset.clevr_count_70k import (
+            get_clevr_count_70k_rl_dataset,
+        )
         return get_clevr_count_70k_rl_dataset(path, split, processor, rank, world_size)
     else:
         raise ValueError(
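
The dispatcher is keyed on a substring of the dataset path plus the training type, with lazy imports per branch. A hypothetical call (model and dataset paths are illustrative, not taken from this diff; parameter names follow the branches above):

# Hypothetical usage of get_custom_dataset; paths and split are examples.
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")  # example model
dataset = get_custom_dataset(
    path="/data/clevr_count_70k",  # substring match selects the clevr branch
    split="train",
    training_type="sft",           # "sft" or "rl", per the branches above
    processor=processor,           # required by the clevr_count_70k loaders
    rank=0,
    world_size=1,
)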

View File

@@ -1,5 +1,7 @@
 from typing import Dict, List
+
 import torch
+
 from arealite.engine.ppo.actor import PPOActor, PPOActorConfig
 from arealite.engine.vl_fsdp_engine import VL_FSDPEngine

View File

@@ -1,6 +1,6 @@
 import os
 import time
-from typing import Any, Callable, Dict, List, Optional
+from typing import Optional

 import torch
 import torch.distributed as dist
@@ -10,30 +10,19 @@ from torch.distributed.checkpoint.state_dict import (
     StateDictOptions,
     get_model_state_dict,
 )
-from transformers import (
-    AutoModelForImageTextToText,
-    get_constant_schedule_with_warmup,
-    get_linear_schedule_with_warmup,
-)
+from transformers import AutoModelForImageTextToText

 from arealite.api.cli_args import TrainEngineConfig
-from arealite.api.engine_api import (
-    FinetuneSpec,
-    SaveLoadMeta,
-    WeightUpdateMeta,
-)
+from arealite.api.engine_api import FinetuneSpec, SaveLoadMeta, WeightUpdateMeta
+from arealite.engine.fsdp_engine import FSDPEngine
 from arealite.utils.data import (
     MicroBatchList,
     amend_position_ids,
     pack_tensor_dict,
-    pad_and_stack_tensors_along_first_dim,
     pad_mb_list,
-    reorder_list,
     split_padded_tensor_dict_into_mb_list,
-    unpack_sequence,
     unsqueeze_mb_list,
 )
-from arealite.utils.model import disable_dropout_in_model
 from arealite.utils.fsdp import (
     CPUOffloadPolicy,
     MixedPrecisionPolicy,
@@ -41,10 +30,9 @@ from arealite.utils.fsdp import (
     create_fsdp_device_mesh,
     get_cosine_schedule_with_warmup,
 )
-from realhf.base import logging, name_resolve, names, pkg_version
+from arealite.utils.model import disable_dropout_in_model
 from realhf.api.core.data_api import load_hf_processor_and_tokenizer
-from arealite.engine.fsdp_engine import FSDPEngine
+from realhf.base import logging, name_resolve, names, pkg_version

 logger = logging.getLogger("FSDPEngine")

View File

@@ -1,11 +1,8 @@
 import time

 from arealite.api.cli_args import InferenceEngineConfig
+from arealite.api.io_struct import VLMRequest, VLMResponse
 from arealite.engine.sglang_remote import RemoteSGLangEngine
-from arealite.api.io_struct import (
-    VLMRequest,
-    VLMResponse
-)
 from arealite.utils.http import arequest_with_retry
 from realhf.base import logging, pkg_version

File diff suppressed because one or more lines are too long

View File

@@ -1,19 +0,0 @@
# Use a pipeline as a high-level helper
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained(
    pretrained_model_name_or_path="/storage/openpsi/models/Qwen2.5-VL-32B-Instruct"
)
input_ids = [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 151652, 151655, 151653, 4340, 1657, 3589, 525, 1052, 304, 279, 2168, 30, 151645, 198, 151644, 77091, 198]
decoded_text = processor.tokenizer.decode(input_ids)
print(decoded_text)

# pipe = pipeline("image-text-to-text", model="/storage/openpsi/models/Qwen2.5-VL-32B-Instruct", device_map="auto")
# messages = [
#     {
#         "role": "user",
#         "content": [
#             {"type": "image", "url": "output_image.jpg"},
#             {"type": "text", "text": "What is shown in the image?"}
#         ]
#     },
# ]
# print(pipe(text=messages))

View File

@@ -1,6 +1,6 @@
 import asyncio
 from typing import Any, Dict, Optional
-from PIL import Image

 import aiohttp

 from realhf.base import logging

View File

@@ -1,10 +1,11 @@
-from io import BytesIO
 import base64
 import math
-from torch import Tensor
-from typing import Any, Dict, List, Optional, Union
-from PIL.Image import Image as ImageObject
 from dataclasses import MISSING
+from io import BytesIO
+from typing import List
+
+from PIL.Image import Image as ImageObject

 def image2base64(images: List[ImageObject]|ImageObject)-> List[str]|str:
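
The body of image2base64 is not shown in this hunk; a plausible standalone implementation of such a helper (PNG encoding and UTF-8 decoding are assumptions, not taken from the repo) would be:

# Sketch of a PIL-image-to-base64 helper matching the signature above.
import base64
from io import BytesIO
from typing import List, Union

from PIL.Image import Image as ImageObject


def image2base64(images: Union[List[ImageObject], ImageObject]) -> Union[List[str], str]:
    def encode(img: ImageObject) -> str:
        buf = BytesIO()
        img.save(buf, format="PNG")  # encoding format is an assumption
        return base64.b64encode(buf.getvalue()).decode("utf-8")

    # Mirror the union signature: a list maps element-wise, a single image
    # returns a single string.
    if isinstance(images, list):
        return [encode(img) for img in images]
    return encode(images)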

View File

@@ -1,7 +1,7 @@
 import getpass
 import os

-from transformers import PreTrainedTokenizerFast, AutoProcessor
+from transformers import AutoProcessor, PreTrainedTokenizerFast

 from arealite.api.cli_args import SaverConfig
 from arealite.api.engine_api import TrainEngine

View File

@@ -3,12 +3,14 @@ import uuid
 import torch
 from tensordict import TensorDict
-from transformers import PreTrainedTokenizerFast,AutoProcessor
+from transformers import AutoProcessor, PreTrainedTokenizerFast

 from arealite.api.cli_args import GenerationHyperparameters
 from arealite.api.io_struct import VLMRequest
-from arealite.workflow.rlvr import RLVRWorkflow
-from arealite.utils.padding import concat_padded_tensors
+from arealite.utils.data import concat_padded_tensors
 from arealite.utils.image import image2base64
+from arealite.workflow.rlvr import RLVRWorkflow

 class VL_RLVRWorkflow(RLVRWorkflow):
     def __init__(