# Mirror of https://github.com/inclusionAI/AReaL
import asyncio
|
|
import os
|
|
import random
|
|
import shutil
|
|
import threading
|
|
import time
|
|
import traceback
|
|
from concurrent.futures import Future, ProcessPoolExecutor
|
|
from datetime import datetime
|
|
from queue import Empty, Full, Queue
|
|
from typing import TYPE_CHECKING, Any, Callable, Dict, List
|
|
|
|
import aiohttp
|
|
import requests
|
|
import torch.distributed as dist
|
|
import uvloop
|
|
from tensordict import TensorDict
|
|
from torchdata.stateful_dataloader import StatefulDataLoader
|
|
|
|
from arealite.api.cli_args import InferenceEngineConfig
|
|
from arealite.api.engine_api import InferenceEngine
|
|
from arealite.api.io_struct import (
|
|
FinetuneSpec,
|
|
LLMRequest,
|
|
LLMResponse,
|
|
RolloutStat,
|
|
VLMRequest,
|
|
VLMResponse,
|
|
WeightUpdateMeta,
|
|
)
|
|
from arealite.utils.data import concat_padded_tensors
|
|
from arealite.utils.http import arequest_with_retry, get_default_connector
|
|
from realhf.base import logging, name_resolve, names
|
|
|
|
# Imported for type annotations only; RolloutWorkflow is referenced as a string
# annotation in this module, so the import is not needed at runtime.
if TYPE_CHECKING:
    from arealite.api.workflow_api import RolloutWorkflow

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Timeout in seconds used when polling: passed to asyncio.wait() in the rollout
# loop and to output_queue.get() in RemoteSGLangEngine.wait().
ROLLOUT_POLL_WAIT_TIME = 0.05

# Maximum number of request ids whose chosen server address is remembered by
# RemoteSGLangEngine.agenerate(); the oldest entry is evicted beyond this size.
RID_CACHE_SIZE = 128
|
|
|
|
|
|
class RemoteSGLangEngine(InferenceEngine):
    """Client-side inference engine that talks to remote SGLang HTTP servers.

    Server addresses come from the ``AREAL_LLM_SERVER_ADDRS`` environment
    variable (comma-separated; each entry is used as ``http://<address>``).
    Rollout workflows are submitted through an input queue, executed as
    asyncio tasks inside a dedicated background thread, and their results are
    handed back through an output queue.
    """

    def __init__(self, config: InferenceEngineConfig):
        """Read the server list, wait for every server, and set up local state.

        Args:
            config: Engine configuration. ``max_concurrent_rollouts`` is
                filled in from ``consumer_batch_size`` when it is falsy
                (the config object is mutated in place).

        Raises:
            RuntimeError: If no server address is configured, or if a server
                does not pass its health check within ``config.setup_timeout``.
        """
        config.max_concurrent_rollouts = (
            config.max_concurrent_rollouts or config.consumer_batch_size
        )
        self.config = config

        self.rid_to_address = {}
        # Maintain the addresses for the recent 128 requests
        self.rid_queue = []

        # An unset variable must yield the RuntimeError below rather than an
        # AttributeError on None, and "".split(",") == [""] must not count as
        # a configured server, so blank entries are dropped.
        raw_addresses = os.getenv("AREAL_LLM_SERVER_ADDRS", "")
        self.addresses = [
            entry.strip() for entry in raw_addresses.split(",") if entry.strip()
        ]
        if not self.addresses:
            raise RuntimeError("No configured SGLang servers.")
        logger.info("Waiting for server ready...")
        for addr in self.addresses:
            self._wait_for_server(addr)
        logger.info("Servers are all ready!")

        # Start round-robin scheduling from a random server.
        self.server_idx = random.randint(0, len(self.addresses) - 1)

        qsize = config.queue_size or config.max_concurrent_rollouts * 16
        self.input_queue = Queue(maxsize=qsize)
        self.output_queue = Queue(maxsize=qsize)
        # Results already pulled from output_queue but not yet returned by wait().
        self.result_cache = []

        self.exiting = threading.Event()
        self.paused = threading.Event()
        self.lock = threading.Lock()

        self.rollout_stat = RolloutStat()
        self.distributed_weight_update_initialized = False

        self._version = 0

    def _wait_for_server(self, address):
        """Block until ``address`` passes the health check.

        Polls once per second until ``config.setup_timeout`` seconds elapse.

        Raises:
            RuntimeError: If the server is not healthy before the timeout.
        """
        base_url = f"http://{address}"
        tik = time.time()
        while time.time() - tik < self.config.setup_timeout:
            if self.check_health(base_url):
                return
            time.sleep(1)
        raise RuntimeError("server launch failed")

    def check_health(self, base_url):
        """Return True if ``GET {base_url}/health`` answers with HTTP 200.

        Any ``requests`` exception (connection error, timeout, ...) is
        reported as an unhealthy server rather than raised.
        """
        # Check server endpoint
        try:
            response = requests.get(f"{base_url}/health", timeout=30)
            return response.status_code == 200
        except requests.exceptions.RequestException:
            return False

    def initialize(self, addr: str | None, ft_spec: FinetuneSpec = None):
        """Start the weight-update worker process and the rollout thread.

        ``addr`` and ``ft_spec`` are accepted but not used by this
        implementation.
        """
        self.rollout_tasks: Dict[str, asyncio.Task] = {}

        self.executor = ProcessPoolExecutor(max_workers=1)
        self.rollout_thread = threading.Thread(target=self._rollout_thread)
        self.rollout_thread.start()

    def destroy(self):
        """Shut down the executor, signal the rollout thread to exit, and join it."""
        self.executor.shutdown()
        self.exiting.set()
        self.rollout_thread.join()

    def set_version(self, version):
        """Set the current weight version (guarded by the engine lock)."""
        with self.lock:
            self._version = version

    def get_version(self):
        """Return the current weight version (guarded by the engine lock)."""
        with self.lock:
            return self._version

    def _rollout_thread(self):
        """Thread that runs the rollout loop."""
        try:
            uvloop.run(self._rollout_thread_async())
        except Exception:
            traceback.print_exc()
        finally:
            # This loop is the only producer of output_queue. Once it has
            # stopped (including after an unexpected failure), flag the engine
            # as exiting so wait()/prepare_batch() fail fast instead of
            # polling for results that can never arrive.
            self.exiting.set()

    async def _rollout_thread_async(self):
        """Run the rollout scheduling loop until ``exiting`` is set.

        Each iteration starts as many queued workflows as the current capacity
        allows (unless paused), waits briefly for any running task to finish,
        and moves finished trajectories to the output queue.
        """
        rollout_tasks = self.rollout_tasks
        rid = 0

        # NOTE: session is not thread-safe, but we only submit requests in the sub-thread.
        self.session = aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(
                total=self.config.request_timeout,
                sock_connect=self.config.request_timeout,
                connect=self.config.request_timeout,
            ),
            read_bufsize=1024 * 1024 * 10,
            connector=get_default_connector(),
        )

        try:
            while not self.exiting.is_set():
                # Check capacity
                capacity = self.get_capacity()
                # Create new rollout task
                while (
                    capacity > 0
                    and not self.paused.is_set()
                    and self.input_queue.qsize() > 0
                ):
                    data, workflow = self.input_queue.get_nowait()
                    logger.debug(f"Get data from puller: {data}")
                    task = asyncio.create_task(
                        workflow.arun_episode(self, data), name=str(rid)
                    )
                    with self.lock:
                        rollout_tasks[str(rid)] = task
                        self.rollout_stat.submitted += 1
                        self.rollout_stat.running += 1
                        if self.config.enable_rollout_tracing:
                            logger.info(
                                f"Submit rollout rid {rid}. "
                                f"Submit: {self.rollout_stat.submitted}, "
                                f"running: {self.rollout_stat.running}, "
                                f"accepted: {self.rollout_stat.accepted}."
                            )
                    capacity -= 1
                    rid += 1
                # Wait for rollout completion
                with self.lock:
                    tasks = list(rollout_tasks.values())
                done = []
                if tasks:
                    done, _ = await asyncio.wait(
                        tasks,
                        timeout=ROLLOUT_POLL_WAIT_TIME,
                        return_when=asyncio.FIRST_COMPLETED,
                    )
                # Collect done results
                for task in done:
                    traj: TensorDict = await task
                    task_rid = task.get_name()
                    with self.lock:
                        rollout_tasks.pop(task_rid)
                        # Counted as accepted up front; wait() decrements the
                        # counter again if should_accept rejects the result.
                        self.rollout_stat.accepted += 1

                    try:
                        self.output_queue.put_nowait(traj)
                    except Full:
                        raise RuntimeError(
                            "Output queue full. Please increase queue_size."
                        )

                    with self.lock:
                        self.rollout_stat.running -= 1
                        if self.config.enable_rollout_tracing:
                            logger.info(
                                f"Finish rollout {task_rid}. "
                                f"Submit: {self.rollout_stat.submitted}, "
                                f"running: {self.rollout_stat.running}, "
                                f"accepted: {self.rollout_stat.accepted}."
                            )
                await asyncio.sleep(1)
        except Exception:
            traceback.print_exc()
        finally:
            # Cancel remaining tasks. Snapshot under the lock, but cancel and
            # await outside of it: a threading.Lock must not be held across an
            # await, otherwise other threads (and task cleanup code that needs
            # the lock) would be blocked for the whole cancellation.
            with self.lock:
                pending = [t for t in rollout_tasks.values() if not t.done()]
            for task in pending:
                task.cancel()
            if pending:
                await asyncio.gather(*pending, return_exceptions=True)
            # Release the HTTP connections owned by this loop.
            await self.session.close()

    def choose_server(self) -> str:
        """Pick the next server address according to ``config.schedule_policy``.

        Raises:
            NotImplementedError: For any policy other than ``"round_robin"``.
        """
        with self.lock:
            if self.config.schedule_policy == "round_robin":
                server = self.addresses[self.server_idx]
                self.server_idx = (self.server_idx + 1) % len(self.addresses)
                return server
        raise NotImplementedError("Only round-robin scheduling is implemented.")

    async def agenerate(
        self, req: LLMRequest | VLMRequest
    ) -> LLMResponse | VLMResponse:
        """Async version of generate using aiohttp.

        Sends ``/generate`` requests until the server reports a ``"stop"``
        finish reason or ``gconfig.max_new_tokens`` tokens have been
        accumulated, re-submitting with the tokens generated so far whenever a
        request ends early.

        Raises:
            ValueError: If ``req.gconfig.n_samples`` is not 1.
        """
        # Prepare request payload
        gconfig = req.gconfig
        stop_token_ids = gconfig.stop_token_ids

        if gconfig.n_samples != 1:
            raise ValueError(
                "RemoteSGLangEngine does not support n_samples > 1. "
                "Please call generate for multiple times with n_samples = 1."
            )
        sample_params = {
            "top_p": gconfig.top_p,
            "top_k": gconfig.top_k,
            "max_new_tokens": gconfig.max_new_tokens,
            "temperature": 0.0 if gconfig.greedy else gconfig.temperature,
            "stop_token_ids": stop_token_ids,
        }
        if isinstance(req, VLMRequest):
            # VLMRequest has image_data
            payload = {
                "input_ids": req.input_ids.copy(),
                "image_data": req.image_data,  # ImageObject or str
                "sampling_params": sample_params,
                "return_logprob": True,
                "stream": False,
            }
        else:
            # NOTE: rid should NOT be passed in payload
            payload = {
                "input_ids": req.input_ids.copy(),
                "sampling_params": sample_params,
                "return_logprob": True,
                "stream": False,
            }

        # Make request
        start_time = time.perf_counter()
        accumulated_output_tokens = []
        accumulated_output_logprobs = []
        accumulated_versions = []

        # A single "rid" shares the same server to allow KV cache reuse
        if req.rid in self.rid_to_address:
            server_addr = self.rid_to_address[req.rid]
        else:
            server_addr = self.choose_server()
            if len(self.rid_queue) >= RID_CACHE_SIZE:
                # Remove the oldest entry if cache is full
                oldest_rid = self.rid_queue.pop(0)
                self.rid_to_address.pop(oldest_rid, None)
            self.rid_to_address[req.rid] = server_addr
            self.rid_queue.append(req.rid)

        # Deal with rollout interruption
        # "abort" is the stop reason for later v0.4.9.post2 after
        # we call the pause_generation endpoint
        stop_reason = None
        while (
            stop_reason != "stop"
            and len(accumulated_output_tokens) < gconfig.max_new_tokens
        ):
            # Request is interrupted, wait for some time to avoid interfering
            # with update weights requests
            if stop_reason is not None:
                await asyncio.sleep(0.5)

            # loop until the generation is complete
            result = await arequest_with_retry(
                session=self.session,
                addr=server_addr,
                endpoint="/generate",
                payload=payload,
                method="POST",
                max_retries=self.config.request_retries,
                timeout=self.config.request_timeout,
            )

            meta_info = result["meta_info"]
            # Check if generation is complete
            finish_reason = meta_info["finish_reason"]
            stop_reason = finish_reason["type"]
            if (
                stop_reason == "abort"
                and finish_reason.get("message") == "Abort before prefill"
            ):
                # Nothing from this response is consumed; retry as-is.
                continue

            # Parse response: each entry is indexed as (logprob, token, ...)
            output_tokens = [x[1] for x in meta_info["output_token_logprobs"]]
            output_logprobs = [x[0] for x in meta_info["output_token_logprobs"]]

            # Update accumulated outputs
            accumulated_output_tokens.extend(output_tokens)
            accumulated_output_logprobs.extend(output_logprobs)
            # FIXME: Update with actual server versions
            accumulated_versions.extend([-1] * len(output_tokens))

            # Continue from what has been generated so far, with the
            # remaining token budget.
            payload["input_ids"] += result["output_ids"]
            sample_params["max_new_tokens"] -= len(output_tokens)

        latency = time.perf_counter() - start_time

        if isinstance(req, VLMRequest):
            response = VLMResponse(
                input_tokens=req.input_ids,
                input_images=req.image_data,
                output_tokens=accumulated_output_tokens,
                output_logprobs=accumulated_output_logprobs,
                output_versions=accumulated_versions,
                stop_reason=stop_reason,
                latency=latency,
                ttft=latency,  # Simplified for non-streaming
            )
        else:
            response = LLMResponse(
                input_tokens=req.input_ids,
                output_tokens=accumulated_output_tokens,
                output_logprobs=accumulated_output_logprobs,
                output_versions=accumulated_versions,
                stop_reason=stop_reason,
                latency=latency,
                ttft=latency,  # Simplified for non-streaming
            )
        return response

    def update_weights(self, meta: WeightUpdateMeta):
        """Pause generation on all servers and start an asynchronous weight update.

        Args:
            meta: Update description; ``meta.type`` must be ``"nccl"`` or
                ``"disk"``.

        Returns:
            The future of the update job submitted to the process pool. When it
            completes, ``/continue_generation`` is posted to every server.

        Raises:
            NotImplementedError: If ``meta.type`` is not supported.
        """
        # Validate before touching the servers: rejecting an unsupported type
        # after pausing would leave every server paused with nothing to
        # resume it.
        if meta.type not in ("nccl", "disk"):
            raise NotImplementedError(f"Unsupported weight update type: {meta.type}")

        for addr in self.addresses:
            res = requests.post(f"http://{addr}/pause_generation")
            res.raise_for_status()

        if meta.type == "nccl":
            fut = self.executor.submit(
                update_weights_from_distributed,
                meta,
                self.addresses,
                self.config.request_timeout,
                not self.distributed_weight_update_initialized,
            )

            def _mark_group_initialized(_done):
                # NOTE(review): the flag is set whether or not the job
                # succeeded (unchanged behavior) - confirm this is intended.
                self.distributed_weight_update_initialized = True

            fut.add_done_callback(_mark_group_initialized)
        else:
            # Update weights from disk
            # Use ProcessPool to bypass python GIL for running async coroutines
            fut = self.executor.submit(
                update_weights_from_disk,
                self.config.experiment_name,
                self.config.trial_name,
                self.get_version(),
                self.addresses,
                meta.path,
                self.config.request_retries,
                self.config.request_timeout,
            )

            def _remove_checkpoint(_done):
                shutil.rmtree(meta.path, ignore_errors=True)

            fut.add_done_callback(_remove_checkpoint)

        def _resume_generation(_done):
            for addr in self.addresses:
                res = requests.post(f"http://{addr}/continue_generation")
                res.raise_for_status()

        fut.add_done_callback(_resume_generation)
        return fut

    def get_capacity(self):
        """Return how many new rollouts may be started right now.

        The result is the smaller of the free concurrency slots and the
        staleness budget
        ``(max_head_offpolicyness + version + 1) * consumer_bs - (accepted + running)``.
        Both limits are divided by the distributed world size when
        ``torch.distributed`` is initialized. The value may be zero or negative.
        """
        if dist.is_initialized():
            world_size = dist.get_world_size()
        else:
            world_size = 1

        max_concurrent_rollouts = max(
            1, self.config.max_concurrent_rollouts // world_size
        )
        capacity = max_concurrent_rollouts - len(self.rollout_tasks)
        # Staleness control
        version = self.get_version()
        ofp = self.config.max_head_offpolicyness
        with self.lock:
            sample_cnt = self.rollout_stat.accepted + self.rollout_stat.running
        consumer_bs = max(1, self.config.consumer_batch_size // world_size)
        capacity = min(capacity, (ofp + version + 1) * consumer_bs - sample_cnt)
        return capacity

    def submit(self, data: Dict[str, Any], workflow: "RolloutWorkflow") -> None:
        """Enqueue one ``(data, workflow)`` pair for the rollout thread.

        Raises:
            RuntimeError: If the input queue is full.
        """
        try:
            self.input_queue.put_nowait((data, workflow))
        except Full:
            raise RuntimeError("Input queue full. Please increase queue_size.")

    def wait(
        self,
        count: int,
        timeout: float | None = None,
        should_accept: Callable | None = None,
    ) -> TensorDict:
        """Block until ``count`` accepted results are available and return them.

        Args:
            count: Number of results to return.
            timeout: Maximum seconds to wait; a falsy value means seven days.
            should_accept: Optional predicate; results for which it returns a
                falsy value are dropped and un-counted from the accepted stat.

        Returns:
            The first ``count`` cached results merged with
            ``concat_padded_tensors``; surplus results stay cached for the
            next call.

        Raises:
            RuntimeError: If the engine is exiting.
            TimeoutError: If fewer than ``count`` results arrived in time.
        """
        tik = time.perf_counter()
        accepted = len(self.result_cache)
        timeout = timeout or float(7 * 24 * 3600)
        while (
            accepted < count
            and not self.exiting.is_set()
            and time.perf_counter() - tik < timeout
        ):
            try:
                result = self.output_queue.get(timeout=ROLLOUT_POLL_WAIT_TIME)
                if should_accept is None or should_accept(result):
                    self.result_cache.append(result)
                    accepted += 1
                else:
                    with self.lock:
                        self.rollout_stat.accepted -= 1
            except Empty:
                pass
        if self.exiting.is_set():
            raise RuntimeError("Rollout engine is exiting, cannot wait for results.")
        if accepted < count:
            raise TimeoutError(
                f"Timed out waiting for {count} rollouts, " f"only received {accepted}."
            )
        results, self.result_cache = (
            self.result_cache[:count],
            self.result_cache[count:],
        )
        return concat_padded_tensors(results)

    def rollout_batch(
        self, data: List[Dict[str, Any]], workflow: "RolloutWorkflow"
    ) -> TensorDict:
        """Submit a batch of requests to the inference engine and wait for the results."""
        for item in data:
            self.submit(item, workflow)
        return self.wait(count=len(data))

    def prepare_batch(
        self,
        dataloader: StatefulDataLoader,
        workflow: "RolloutWorkflow",
    ):
        """Keep the input queue fed from ``dataloader`` and return one batch of results.

        Loops until ``dataloader.batch_size`` results are available. The data
        iterator is kept on the instance between calls and restarted when the
        dataloader is exhausted.
        """
        if not hasattr(self, "data_generator"):
            self.data_generator = iter(dataloader)
        assert dataloader.batch_size is not None
        while True:
            # Submit at least two batches to allow maximum overlap
            if (
                self.get_capacity() + dataloader.batch_size > 0
                and self.input_queue.qsize() + dataloader.batch_size
                < self.input_queue.maxsize
            ):
                try:
                    data = next(self.data_generator)
                except StopIteration:
                    self.data_generator = iter(dataloader)
                    data = next(self.data_generator)
                for item in data:
                    self.submit(item, workflow=workflow)
            try:
                return self.wait(dataloader.batch_size, timeout=1)
            except TimeoutError:
                # Not enough results yet; try to submit more and wait again.
                pass

    def pause(self):
        """Stop the rollout thread from starting new rollout tasks."""
        self.paused.set()

    def resume(self):
        """Allow the rollout thread to start new rollout tasks again."""
        self.paused.clear()
|
|
|
|
|
|
def update_weights_from_disk(
    experiment_name,
    trial_name,
    model_version,
    addresses,
    path,
    request_retries,
    request_timeout,
):
    """Ask every server to load weights from ``path`` and wait for all of them.

    First blocks (up to 120 seconds) on the name-resolve entry for
    ``(experiment_name, trial_name, model_version)``, whose value is read as
    the checkpoint save timestamp, then posts ``/update_weights_from_disk`` to
    each address concurrently.

    Args:
        experiment_name: Experiment name used to build the name-resolve key.
        trial_name: Trial name used to build the name-resolve key.
        model_version: Model version used to build the name-resolve key.
        addresses: Server addresses to update.
        path: Checkpoint location sent to the servers as ``model_path``.
        request_retries: Maximum retries per request.
        request_timeout: Timeout in seconds for the session and each request.
    """

    async def _fn():
        # Wait for model checkpoints of meta.version
        update_name = names.update_weights_from_disk(
            experiment_name, trial_name, model_version
        )
        save_timestamp = float(name_resolve.wait(update_name, timeout=120))
        load_timestamp = datetime.now().timestamp()
        logger.info(
            f"Begin update weights from {path}, responded in {(load_timestamp - save_timestamp):.2f}s"
        )
        # "async with" guarantees the session is closed even when one of the
        # requests raises, instead of leaking its connections.
        async with aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(
                total=request_timeout,
                sock_connect=request_timeout,
                connect=request_timeout,
            ),
            read_bufsize=1024 * 1024 * 10,
            connector=get_default_connector(),
        ) as session:
            jobs = [
                arequest_with_retry(
                    addr=addr,
                    session=session,
                    endpoint="/update_weights_from_disk",
                    payload=dict(model_path=str(path)),
                    method="POST",
                    max_retries=request_retries,
                    timeout=request_timeout,
                )
                for addr in addresses
            ]
            await asyncio.gather(*jobs)
        logger.info(
            f"Loading weights done in {(datetime.now().timestamp() - load_timestamp):.2f}s"
        )

    return uvloop.run(_fn())
|
|
|
|
|
|
def update_weights_from_distributed(
    meta: WeightUpdateMeta,
    addresses: List[str],
    request_timeout,
    init_group: bool,
):
    """Trigger a distributed weight update on every server and wait for completion.

    When ``init_group`` is true, the weight-update group is first initialized
    on all servers concurrently; afterwards ``/update_weights_from_distributed``
    is posted to each address concurrently with the parameter names, dtypes
    and shapes taken from ``meta.nccl_param_specs``.
    """

    def _update_request(server_addr):
        # A fresh payload is built for every server.
        specs = meta.nccl_param_specs
        return arequest_with_retry(
            addr=server_addr,
            endpoint="/update_weights_from_distributed",
            payload={
                "names": [spec.name for spec in specs],
                "dtypes": [spec.dtype for spec in specs],
                "shapes": [spec.shape for spec in specs],
                "group_name": meta.nccl_group_name,
            },
            method="POST",
            max_retries=1,
            timeout=request_timeout,
        )

    async def _fn():
        tik = time.perf_counter()
        if init_group:
            init_jobs = [
                ainit_weights_update_group(server_addr, idx, meta, request_timeout)
                for idx, server_addr in enumerate(addresses)
            ]
            await asyncio.gather(*init_jobs)
        update_jobs = [_update_request(server_addr) for server_addr in addresses]
        await asyncio.gather(*update_jobs)
        logger.info(f"Distributed update weights done in {time.perf_counter() - tik}s")

    return uvloop.run(_fn())
|
|
|
|
|
|
async def ainit_weights_update_group(
    addr: str,
    server_idx: int,
    meta: WeightUpdateMeta,
    request_timeout: float,
):
    """Initialize the weight-update process group on one server.

    Posts ``/init_weights_update_group`` to ``addr`` with the NCCL master
    address/port, group name, world size and this server's rank offset.

    Args:
        addr: Address of the server to initialize.
        server_idx: Index of the server in the address list; determines the
            rank offset.
        meta: Weight update description; ``meta.alloc_mode`` must be set.
        request_timeout: Request timeout in seconds.

    Raises:
        NotImplementedError: If ``meta.alloc_mode.gen_pp_size`` is not 1.
        RuntimeError: If the server response reports ``success`` as falsy.
    """
    assert meta.alloc_mode is not None
    if meta.alloc_mode.gen_pp_size != 1:
        raise NotImplementedError(
            "NCCL weight update with PP size > 1 is not implemented yet."
        )
    # Each server occupies gen_tp_size consecutive ranks starting at 1; the
    # group has one rank more than the generation world size (presumably rank 0
    # is the weight sender - TODO confirm).
    rank_offset = 1 + server_idx * meta.alloc_mode.gen_tp_size
    payload = {
        "master_address": meta.nccl_master_address,
        "master_port": str(meta.nccl_master_port),
        "rank_offset": rank_offset,
        "world_size": meta.alloc_mode.gen_world_size + 1,
        "backend": "nccl",
        "group_name": meta.nccl_group_name,
    }
    res = await arequest_with_retry(
        addr=addr,
        endpoint="/init_weights_update_group",
        payload=payload,
        method="POST",
        max_retries=1,
        timeout=request_timeout,
    )
    # Explicit check instead of ``assert``: assertions are stripped under
    # ``python -O``, which would silently ignore a failed initialization.
    if not res["success"]:
        raise RuntimeError(
            f"Failed to initialize weights update group on {addr}: {res}"
        )
|