# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0 (the "License").

import asyncio
import time
from asyncio.queues import QueueEmpty
from collections import defaultdict
from dataclasses import asdict
from typing import Dict, Hashable, List

import aiohttp
from aiohttp.client import ClientTimeout
from transformers import PreTrainedTokenizerFast

from realhf.api.cli_args import GenerationHyperparameters
from realhf.api.core.model_api import (
    APIGenerateInput,
    APIGenerateOutput,
    BundledGenerationOutputs,
    GenReqMeta,
)
from realhf.base import constants, logging, name_resolve, names

logger = logging.getLogger(__name__)

GENERATION_POLL_WAIT_TIME = 0.05


class PartialRolloutManager:
    """Manages partial rollouts for a client.

    Generation requests are submitted in chunks, producing at most
    `new_tokens_per_chunk` tokens per request. This reduces the overhead
    of flushing all in-flight requests when model weights are updated.

    This is a workaround: it would be unnecessary if the server could
    pause requests, update weights, and recompute KV caches at any time.
    """
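
    # Illustrative chunking arithmetic (hypothetical numbers): with
    # max_new_tokens=1024 and new_tokens_per_chunk=256, a full-length rollout
    # is produced in up to 4 chunked requests. A weights update between chunks
    # then discards at most one in-flight chunk per sequence instead of the
    # entire generation.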

    def __init__(
        self,
        worker_index: int,
        request_queue: asyncio.Queue,
        reply_queue: asyncio.Queue,
        new_tokens_per_chunk: int,
        tokenizer: PreTrainedTokenizerFast,
        timeout: int,
    ):
        self.worker_index = worker_index

        # qid -> {group_idx -> aiohttp Task}
        self.gen_requests: Dict[Hashable, Dict[int, asyncio.Task]]
        self.gen_requests = defaultdict(dict)

        # NOTE: Grouped generations are managed separately. Store early returned
        # answers in this cache and pop the result when the whole group is done.
        self.gen_cache: Dict[Hashable, Dict[int, APIGenerateOutput]]
        self.gen_cache = defaultdict(dict)

        self.tokenizer = tokenizer

        self.request_queue = request_queue
        self.reply_queue = reply_queue

        self.new_tokens_per_chunk = new_tokens_per_chunk

        self.gserver_manager_addr = None
        self.timeout = timeout

    async def _schedule_request(self, req_meta: GenReqMeta):
        if self.gserver_manager_addr is None:
            # Resolve the address of the gserver manager that schedules requests.
            name = names.gen_server_manager(
                constants.experiment_name(), constants.trial_name()
            )
            self.gserver_manager_addr = name_resolve.wait(name, timeout=300)
            time.sleep(1)  # Wait for the server to start.
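        # The manager's JSON reply is consumed as a plain dict; judging from
        # the callers below it must contain at least {"url": ..., "version": ...}.
        # (This schema is inferred from usage here, not from a spec.)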
        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"http://{self.gserver_manager_addr}/schedule_request",
                json=asdict(req_meta),
                timeout=ClientTimeout(total=self.timeout, sock_connect=30),
            ) as response:
                response.raise_for_status()
                res = await response.json()
                return res

    def get_num_gen_requests(self):
        return len(self.gen_requests)

    async def _run_gen(
        self,
        url,
        qid,
        group_idx,
        prompt_ids,
        input_ids,
        prev_logprobs,
        version_start,
        cur_server_version,
        raw_gconfig,
    ):
        from realhf.impl.model.backend.sglang import SGLangAPIClient

        # Cap this chunk at `new_tokens_per_chunk`, then at the remaining
        # budget: max_new_tokens minus the tokens generated so far
        # (len(input_ids) - len(prompt_ids)).
        max_new_tokens = min(raw_gconfig.max_new_tokens, self.new_tokens_per_chunk)
        max_new_tokens = min(
            max_new_tokens,
            raw_gconfig.max_new_tokens - len(input_ids) + len(prompt_ids),
        )
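        # E.g. (hypothetical numbers): with max_new_tokens=512, chunk size 128,
        # a 32-token prompt, and 384 tokens generated so far
        # (len(input_ids) == 416), the remaining budget is 512 - 384 = 128,
        # so this chunk requests min(128, 128) = 128 new tokens.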
        gconfig = raw_gconfig.new(
            n=1,
            max_new_tokens=max_new_tokens,
        )
        assert self.tokenizer.pad_token_id is not None
        assert self.tokenizer.eos_token_id is not None
        # No need to request a weights update here.
        async with SGLangAPIClient(
            generate_url=f"{url}/generate", update_weights_url=""
        ) as api_client:
            res = await api_client.async_add_generate_request(
                APIGenerateInput(
                    qid=qid,
                    prompt_ids=prompt_ids,
                    input_ids=input_ids,
                    gconfig=gconfig,
                    stop_token_ids=[
                        self.tokenizer.pad_token_id,
                        self.tokenizer.eos_token_id,
                    ],
                    return_logprob=True,
                    version_start=version_start,
                    prev_logprobs=prev_logprobs,
                    metadata=dict(
                        group_idx=group_idx,
                        raw_gconfig=raw_gconfig,
                        server_url=url,
                        version=cur_server_version,
                    ),
                ),
                stream=False,
            )
        res.version_end = [cur_server_version for _ in range(res.group_size)]
        return res

    async def _issue_generation(
        self,
        url: str,
        qid: Hashable,
        group_idx: int,
        prompt_ids: List[int],
        input_ids: List[int],
        prev_logprobs: List[float],
        version_start: int,
        raw_gconfig: GenerationHyperparameters,
        cur_server_version: int,
    ):
        """Issue a generation request.

        `input_ids` may be longer than `prompt_ids`: it is the prompt plus any
        tokens generated in earlier chunks. If model weights have been updated,
        the KV cache is refreshed; otherwise the server reuses the radix cache
        with no additional overhead.
        """
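        # E.g. (illustrative): on the first chunk, input_ids == prompt_ids; on
        # a continuation, `refresh_generation` passes
        # input_ids == s.input_ids + s.output_ids[0], so the server merely
        # extends the already-cached prefix.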
        task = asyncio.create_task(
            self._run_gen(
                url,
                qid,
                group_idx,
                prompt_ids,
                input_ids,
                prev_logprobs,
                version_start=version_start,
                cur_server_version=cur_server_version,
                raw_gconfig=raw_gconfig,
            )
        )
        self.gen_requests[qid][group_idx] = task
        # Yield control so the new task gets a chance to start running.
        await asyncio.sleep(0)

    async def refresh_generation(self):
        tasks = []
        for group_requests in self.gen_requests.values():
            tasks += list(group_requests.values())

        done = []
        if tasks:
            # No new checkpoint is available, so wait briefly for the next
            # completed chunk.
            done, _ = await asyncio.wait(
                tasks,
                timeout=GENERATION_POLL_WAIT_TIME,
                return_when=asyncio.FIRST_COMPLETED,
            )
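            # NOTE: With FIRST_COMPLETED and a timeout, `asyncio.wait` returns
            # as soon as any task finishes, or after GENERATION_POLL_WAIT_TIME
            # (0.05 s) with an empty `done` set, so this poll never blocks the
            # event loop for long.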

        for task in done:
            s: APIGenerateOutput = await task
            group_idx = s.metadata["group_idx"]
            raw_gconfig = s.metadata["raw_gconfig"]
            previous_version = s.metadata["version"]

            assert s.group_size == 1
            no_eos = s.no_eos[0]
            gen_len = s.gen_lens[0]

            self.gen_requests[s.qid].pop(group_idx)
            if len(self.gen_requests[s.qid]) == 0:
                self.gen_requests.pop(s.qid)

            if no_eos and gen_len < raw_gconfig.max_new_tokens:
                # The request is unfinished due to chunked generation.
                # Reschedule it and send it back to continue.
                req_meta = GenReqMeta(
                    qid=s.qid,
                    prompt_len=s.prompt_len,
                    group_size=raw_gconfig.n,
                    new_token_budget=raw_gconfig.max_new_tokens,
                    predicted_new_tokens=None,
                    previous_server_url=s.metadata["server_url"],
                    previous_version=previous_version,
                )
                info = await self._schedule_request(req_meta)
                cur_version = info["version"]
                server_url = info["url"]

                # Accumulate logprobs across chunks so the final output carries
                # logprobs for the whole generation. Guard against a `None`
                # prev_logprobs before concatenating.
                prev_logprobs = s.prev_logprobs if s.prev_logprobs is not None else []
                if len(s.output_logprobs) > 0:
                    prev_logprobs = prev_logprobs + s.output_logprobs[0]
                await self._issue_generation(
                    server_url,
                    s.qid,
                    group_idx,
                    s.prompt_ids,
                    s.input_ids + s.output_ids[0],
                    version_start=s.version_start,
                    prev_logprobs=prev_logprobs,
                    raw_gconfig=raw_gconfig,
                    cur_server_version=cur_version,
                )
            else:
                # Generation has finished. Save the result to the cache;
                # it is fetched once the whole group completes.
                self.gen_cache[s.qid][group_idx] = s
                if len(self.gen_cache[s.qid]) >= raw_gconfig.n:
                    gen_results = self.gen_cache.pop(s.qid)
                    output = BundledGenerationOutputs.from_api_outputs(
                        list(gen_results.values())
                    )
                    self.reply_queue.put_nowait(output)

    async def poll_fresh_requests_task(self):
        # Drain at most 8 fresh prompts from the request queue per step.
        for _ in range(8):
            try:
                qid, prompt_token_ids, gconfig = self.request_queue.get_nowait()
                req_meta = GenReqMeta(
                    qid=qid,
                    prompt_len=len(prompt_token_ids),
                    group_size=gconfig.n,
                    new_token_budget=gconfig.max_new_tokens,
                    predicted_new_tokens=None,
                )
                dst_server_info = await self._schedule_request(req_meta)

                for group_idx in range(gconfig.n):
                    await self._issue_generation(
                        dst_server_info["url"],
                        qid,
                        group_idx,
                        prompt_token_ids,
                        prompt_token_ids,
                        version_start=dst_server_info["version"],
                        prev_logprobs=[],
                        raw_gconfig=gconfig,
                        cur_server_version=dst_server_info["version"],
                    )
            except QueueEmpty:
                break

    async def poll_old_requests_task(self):
        # Poll in-flight chunked requests several times per step.
        for _ in range(8):
            await self.refresh_generation()

    async def run_step(self):
        await asyncio.gather(
            self.poll_fresh_requests_task(),
            self.poll_old_requests_task(),
        )
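

# Usage sketch (illustrative only; the (qid, prompt_token_ids, gconfig) queue
# payload follows `poll_fresh_requests_task`, but the surrounding setup is
# hypothetical):
#
#     request_queue, reply_queue = asyncio.Queue(), asyncio.Queue()
#     manager = PartialRolloutManager(
#         worker_index=0,
#         request_queue=request_queue,
#         reply_queue=reply_queue,
#         new_tokens_per_chunk=128,
#         tokenizer=tokenizer,  # a PreTrainedTokenizerFast
#         timeout=1800,
#     )
#     request_queue.put_nowait((qid, prompt_token_ids, gconfig))
#     while True:
#         await manager.run_step()
#         try:
#             out: BundledGenerationOutputs = reply_queue.get_nowait()
#             break
#         except QueueEmpty:
#             pass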