mirror of https://github.com/inclusionAI/AReaL
Merge branch 'lite' of https://github.com/inclusionAI/AReaL into lite
This commit is contained in:
commit 7ab6755379

.github/workflows/test-arealite.yml
@@ -0,0 +1,50 @@
name: Test AReaLite

on:
  push:
    paths:
      - .github/workflows/test-arealite.yml
      - arealite/**
      - ci/**
  workflow_dispatch:

jobs:
  test-arealite:
    runs-on: ubuntu-latest
    concurrency:
      group: test-arealite
    steps:
      - uses: actions/checkout@v4

      - uses: appleboy/ssh-action@v1
        env:
          GIT_REPO_URL: https://github.bibk.top/${{ github.repository }}
          GIT_COMMIT_SHA: ${{ github.sha }}
        with:
          host: ${{ secrets.CI_NODE_ADDR }}
          username: ${{ secrets.CI_NODE_USER }}
          key: ${{ secrets.REMOTE_SSH_KEY }}
          envs: GIT_REPO_URL,GIT_COMMIT_SHA
          script_path: ci/clone_repo.sh

      - uses: appleboy/ssh-action@v1
        env:
          GIT_COMMIT_SHA: ${{ github.sha }}
        with:
          host: ${{ secrets.CI_NODE_ADDR }}
          username: ${{ secrets.CI_NODE_USER }}
          key: ${{ secrets.REMOTE_SSH_KEY }}
          command_timeout: 2h
          envs: GIT_COMMIT_SHA
          script_path: ci/build_env_image.sh

      - uses: appleboy/ssh-action@v1
        env:
          GIT_COMMIT_SHA: ${{ github.sha }}
        with:
          host: ${{ secrets.CI_NODE_ADDR }}
          username: ${{ secrets.CI_NODE_USER }}
          key: ${{ secrets.REMOTE_SSH_KEY }}
          command_timeout: 1h
          envs: GIT_COMMIT_SHA
          script_path: ci/test_arealite.sh

arealite/api/cli_args.py
@@ -68,3 +68,23 @@ class GenerationHyperparameters:
        args = asdict(self)
        args.update(kwargs)
        return GenerationHyperparameters(**args)

@dataclass
class InferenceEngineConfig:
    # Used by remote inference engines.
    server_addrs: List[str] = field(
        default_factory=list,
        metadata={"help": "List of server addresses for inference."},
    )
    schedule_policy: str = field(
        default="round_robin",
        metadata={"help": "Request scheduling policy", "choices": ["round_robin"]},
    )
    request_timeout: float = field(
        default=30.0,
        metadata={"help": "Timeout for HTTP requests."},
    )
    request_retries: int = field(
        default=3,
        metadata={"help": "Number of retries for failed requests."},
    )
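
For orientation, the dataclasses above are consumed roughly as follows. This is a minimal sketch assuming the import path arealite.api.cli_args used elsewhere in this diff; the server addresses and the greedy override are illustrative values, and GenerationHyperparameters is assumed to have defaults for its other fields:

from arealite.api.cli_args import GenerationHyperparameters, InferenceEngineConfig

# Point a remote inference engine at two (hypothetical) SGLang servers.
config = InferenceEngineConfig(
    server_addrs=["10.0.0.1:30000", "10.0.0.2:30000"],
    schedule_policy="round_robin",  # the only accepted choice per the field metadata
    request_timeout=30.0,
    request_retries=3,
)

# new() copies the hyperparameters and overrides selected fields.
gconfig = GenerationHyperparameters()
greedy_gconfig = gconfig.new(greedy=True, temperature=0.0)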

arealite/api/engine_api.py
@@ -1,4 +1,7 @@
import abc
from typing import Callable, Dict, List, Any, Optional, TYPE_CHECKING
import torch
from dataclasses import dataclass, field
from concurrent.futures import Future
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional

@@ -13,6 +16,9 @@ from arealite.api.io_struct import (
    SaveLoadMeta,
    WeightUpdateMeta,
)
from tensordict import TensorDict
if TYPE_CHECKING:
    from arealite.api.workflow_api import RolloutWorkflow


@dataclass

@@ -113,14 +119,14 @@ class InferenceEngine(abc.ABC):
        """Asynchronously generate a response for the given request."""
        raise NotImplementedError()

-    def submit(self, data: Dict[str, Any], workflow) -> None:
+    def submit(self, data: Dict[str, Any], workflow: "RolloutWorkflow") -> None:
        """Asynchronously submit a request to the inference engine. Exits immediately."""
        raise NotImplementedError()

-    def wait(self, count: int, timeout: int) -> Any:
+    def wait(self, count: int, timeout: int) -> TensorDict:
        """Wait for a specified number of requests to complete, with a timeout."""
        raise NotImplementedError()

-    def rollout(self, data: List[Dict[str, Any]], workflow) -> Any:
+    def rollout(self, data: List[Dict[str, Any]], workflow: "RolloutWorkflow") -> TensorDict:
        """Submit a batch of requests to the inference engine and wait for the results."""
        raise NotImplementedError()
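
For orientation, a concrete engine typically implements rollout() by composing the two primitives above. A minimal sketch of such a method body (the timeout value is illustrative, not part of the abstract contract):

    def rollout(self, data: List[Dict[str, Any]], workflow: "RolloutWorkflow") -> TensorDict:
        # Queue every item, then block until the same number of trajectories comes back.
        for item in data:
            self.submit(item, workflow)
        return self.wait(count=len(data), timeout=3600)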

arealite/api/io_struct.py
@@ -170,3 +170,9 @@ class SaveLoadMeta:
    with_optim: bool
    tokenizer: PreTrainedTokenizerFast | None
    base_model_path: str | None

@dataclass
class RolloutStat:
    submitted: int = 0
    accepted: int = 0
    running: int = 0

@@ -0,0 +1,315 @@
import asyncio
import functools
import math
import os
from typing import Any, Callable, Dict, List, Optional

import torch
import torch.distributed as dist
import transformers
from transformers import AutoConfig, AutoModelForCausalLM

from arealite.api.cli_args import (
    EngineConfig,
    MicroBatchSpec,
    ParallelismConfig,
    TrainingArgs,
)
from arealite.api.engine_api import TrainEngine
from arealite.api.io_struct import FinetuneSpec
from arealite.api.llm_client_api import LLMClient
from arealite.utils import (
    get_state_dict_from_repo_id_or_path,
    recorder_list,
    split_dict_tensor_with_cu_seqlens,
    unpack_sequence,
)
from realhf.base import constants


def get_cosine_schedule_with_warmup(
    optimizer: torch.optim.Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    min_lr_ratio: float = 0.0,
    num_cycles: float = 0.5,
    last_epoch: int = -1,
):
    """
    Create a schedule with a learning rate that decreases following the values of the cosine function between the
    initial lr set in the optimizer and 0, after a warmup period during which it increases linearly between 0 and the
    initial lr set in the optimizer.

    Args:
        optimizer (:class:`~torch.optim.Optimizer`):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (:obj:`int`):
            The number of steps for the warmup phase.
        num_training_steps (:obj:`int`):
            The total number of training steps.
        min_lr_ratio (:obj:`float`, `optional`, defaults to 0.0):
            The minimum lr ratio w.r.t. the maximum.
        num_cycles (:obj:`float`, `optional`, defaults to 0.5):
            The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
            following a half-cosine).
        last_epoch (:obj:`int`, `optional`, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    assert 0 <= min_lr_ratio <= 1.0
    coef = (1 - min_lr_ratio) * 0.5
    intercept = (1 + min_lr_ratio) * 0.5

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(
            max(1, num_training_steps - num_warmup_steps)
        )
        x = math.cos(math.pi * float(num_cycles) * 2.0 * progress)
        return max(0.0, x * coef + intercept)

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)
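
For orientation, the warmup-then-cosine schedule above is wired up like this. A minimal sketch; the module path of this engine file and the optimizer settings are assumptions for illustration:

import torch

from arealite.engine.hf_engine import get_cosine_schedule_with_warmup  # assumed module path

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# 100 warmup steps, 1000 total steps, decaying to 10% of the peak LR.
scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=100, num_training_steps=1000, min_lr_ratio=0.1
)
for _ in range(1000):
    optimizer.step()  # the actual training step goes here
    scheduler.step()  # advance the learning-rate schedule once per step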

class HFEngine(TrainEngine):
    """Simplified HF engine for transformer models."""

    def __init__(self, args: TrainingArgs, engine_config: EngineConfig):
        super().__init__(args, engine_config)

        self.model = None
        self.optimizer = None
        self.lr_scheduler = None
        self.model_config = None

        self.weight_update_group_initialized = False

    def init_distributed(self, config: ParallelismConfig, ft_spec: FinetuneSpec):
        """Initialize the model on a single node."""
        if not dist.is_initialized():
            dist.init_process_group(backend="nccl")
        if dist.get_world_size() > 1:
            raise RuntimeError(
                "Distributed training is not supported in this engine. "
                "Please use FSDP for distributed training."
            )
        torch.cuda.set_device("cuda:0")

        dtype = torch.bfloat16 if self.engine_config.bf16 else torch.float16
        self.model_config = AutoConfig.from_pretrained(
            pretrained_model_name_or_path=self.engine_config.path,
            trust_remote_code=True,
        )
        with torch.device("cuda"):
            # Initialize a scratch model from the config.
            model = AutoModelForCausalLM.from_config(
                self.model_config,
                torch_dtype=dtype,
                attn_implementation="flash_attention_2",
            )

        model = model.cuda()

        self.model = model

        # Set up optimizer
        optimizer_config = self.engine_config.optimizer
        if optimizer_config is not None:
            assert (
                optimizer_config.type == "adam"
            ), "Only AdamW optimizer is supported in this engine."
            lr = optimizer_config.lr
            weight_decay = optimizer_config.weight_decay
            beta1 = optimizer_config.beta1
            beta2 = optimizer_config.beta2
            eps = optimizer_config.eps

            self.optimizer = torch.optim.AdamW(
                self.model.parameters(),
                lr=lr,
                weight_decay=weight_decay,
                betas=(beta1, beta2),
                eps=eps,
            )
            total_train_steps = ft_spec.total_train_steps
            num_warmup_steps = int(
                optimizer_config.warmup_steps_proportion * total_train_steps
            )

            self.lr_scheduler = get_cosine_schedule_with_warmup(
                self.optimizer,
                num_warmup_steps,
                total_train_steps,
                min_lr_ratio=optimizer_config.min_lr_ratio,
            )

    def train(self, mode: bool = True):
        """Set the module in training mode."""
        return self.model.train(mode)

    def train_batch(
        self,
        input_: Dict,
        mb_spec: MicroBatchSpec,
        loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
        loss_weight_fn: Callable[[Dict], float],
    ) -> Dict:
        """Train on a batch using gradient accumulation."""
        assert self.optimizer is not None
        assert self.lr_scheduler is not None

        self.optimizer.zero_grad()
        mb_splits = split_dict_tensor_with_cu_seqlens(input_, mb_spec)
        total_loss_weight = torch.tensor(
            sum([loss_weight_fn(mb) for mb in mb_splits.mbs]), dtype=torch.float32
        )
        assert total_loss_weight != 0

        # Accumulate gradients over micro-batches, weighting each loss by its share of the batch.
        for mb_input in mb_splits.mbs:
            outputs = self.model(**mb_input)
            loss = loss_fn(outputs.logits, mb_input)
            loss_scale = loss_weight_fn(mb_input) / total_loss_weight
            loss *= loss_scale
            loss.backward()

        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            self.engine_config.optimizer.gradient_clipping,
            norm_type=2.0,
            error_if_nonfinite=False,
            foreach=None,
        )
        current_lr = self.lr_scheduler.get_last_lr()[0]
        # Optimizer step
        self.optimizer.step()

        return {
            "grad_norm": grad_norm,
            "lr": current_lr,
        }
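
train_batch() and eval_batch() delegate the objective to the caller through loss_fn and loss_weight_fn. A minimal sketch of such callbacks for token-level cross-entropy; the "input_ids" and "loss_mask" keys are assumptions about the packed micro-batch layout, not guarantees of the engine:

import torch
import torch.nn.functional as F

def loss_fn(logits: torch.Tensor, mb_input: dict) -> torch.Tensor:
    # Next-token cross-entropy; positions where loss_mask == 0 are ignored.
    labels = mb_input["input_ids"][:, 1:]
    mask = mb_input["loss_mask"][:, 1:].float()
    logp = F.log_softmax(logits[:, :-1].float(), dim=-1)
    nll = -logp.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
    return (nll * mask).sum() / mask.sum().clamp(min=1.0)

def loss_weight_fn(mb_input: dict) -> float:
    # Weight each micro-batch by its number of supervised tokens.
    return float(mb_input["loss_mask"].sum().item())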

    @torch.no_grad()
    def eval_batch(
        self,
        input_: Dict,
        mb_spec: MicroBatchSpec,
        loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
        loss_weight_fn: Callable[[Dict], float],
    ) -> torch.Tensor | None:
        """Evaluate on a batch."""
        mb_splits = split_dict_tensor_with_cu_seqlens(input_, mb_spec)
        total_loss_weight = torch.tensor(
            sum([loss_weight_fn(mb) for mb in mb_splits.mbs]), dtype=torch.float32
        )
        assert total_loss_weight != 0

        total_loss = 0.0
        total_weight = 0.0

        for mb_input in mb_splits.mbs:
            outputs = self.model(**mb_input)
            loss = loss_fn(outputs.logits, mb_input)

            # Simple weight calculation (could be improved)
            loss_scale = loss_weight_fn(mb_input) / total_loss_weight
            total_loss += loss.item() * loss_scale
            total_weight += loss_scale

        return torch.tensor(total_loss / total_weight)

    @torch.no_grad()
    def forward(
        self,
        input_: Dict,
        mb_spec: MicroBatchSpec,
        output_seqlens: List[int] | None = None,
        post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
        aggregate_fn: Callable[[List[Any]], Any] = functools.partial(torch.cat, dim=1),
    ) -> Any | None:
        """Forward pass with optional post-processing."""
        mb_splits = split_dict_tensor_with_cu_seqlens(input_, mb_spec)
        if output_seqlens is None:
            cu_seqlens = input_["cu_seqlens"]
            output_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist()

        results = []
        for mb_input in mb_splits.mbs:
            outputs = self.model(**mb_input)
            if post_hook:
                result = post_hook(outputs.logits, mb_input)
                results.append(result)
            else:
                results.append(outputs.logits)

        res = aggregate_fn(results)
        output_seqlens = [output_seqlens[i] for i in mb_splits.forward_indices]
        unpacked = unpack_sequence(res, lens=output_seqlens, dim=1)
        return aggregate_fn(recorder_list(unpacked, mb_splits.backward_indices))

    def step_lr_scheduler(self):
        """Step the learning rate scheduler."""
        return self.lr_scheduler.step()

    def save_model_to_hf(
        self,
        path: str,
        tokenizer: Optional[transformers.PreTrainedTokenizerFast] = None,
        base_model_path: Optional[str] = None,
    ):
        """Save model in HuggingFace format."""
        if self.model is None:
            raise RuntimeError("Model not initialized")

        os.makedirs(path, exist_ok=True)

        state_dict = {k: v.cpu() for k, v in self.model.state_dict().items()}
        self.model.save_pretrained(path, state_dict=state_dict)
        self.model_config.save_pretrained(path)
        if tokenizer is not None:
            tokenizer.save_pretrained(path)

    def load_model_from_hf(self, path: str):
        """Load model from HuggingFace format."""
        full_state = get_state_dict_from_repo_id_or_path(path)
        self.model.load_state_dict(
            full_state, strict=not self.model_config.tie_word_embeddings
        )
        if self.model_config.tie_word_embeddings:
            self.model.tie_weights()

    def save_optimizer_state(self, path: str):
        """Save optimizer state."""
        if self.optimizer is None:
            raise RuntimeError("Optimizer not initialized")

        os.makedirs(path, exist_ok=True)
        torch.save(self.optimizer.state_dict(), os.path.join(path, "optimizer.pt"))

    def load_optimizer_state(self, path: str):
        """Load optimizer state."""
        if self.optimizer is None:
            raise RuntimeError("Optimizer not initialized")

        optimizer_path = os.path.join(path, "optimizer.pt")
        if os.path.exists(optimizer_path):
            self.optimizer.load_state_dict(
                torch.load(optimizer_path, map_location="cpu")
            )
        else:
            raise RuntimeError(f"Optimizer state file not found: {optimizer_path}")

    async def aupdate_weights_to(self, llm_client: LLMClient):
        path = constants.get_param_realloc_path(self.args)
        self.save_model_to_hf(path)
        tasks = [
            llm_client.aupdate_weights_from_disk(server_info=server_info, path=path)
            for server_info in llm_client.get_healthy_servers()
        ]
        await asyncio.gather(*tasks)

    def update_weights_to(self, llm_client: LLMClient):
        loop = asyncio.new_event_loop()
        try:
            loop.run_until_complete(self.aupdate_weights_to(llm_client))
        finally:
            loop.close()
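
A rough sketch of the checkpoint round-trip the methods above provide, assuming an engine whose init_distributed() has already been called; the path is illustrative:

# engine: an initialized HFEngine
ckpt_dir = "/tmp/hf_ckpt"
engine.save_model_to_hf(ckpt_dir)      # weights + config (+ tokenizer if one is passed)
engine.save_optimizer_state(ckpt_dir)  # writes optimizer.pt alongside the weights

# ...later, restore both pieces before resuming training.
engine.load_model_from_hf(ckpt_dir)
engine.load_optimizer_state(ckpt_dir)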

@@ -0,0 +1,398 @@
import time

from arealite.api.io_struct import (
    LLMRequest,
    LLMResponse,
    WeightUpdateMeta,
    RolloutStat,
)
from arealite.api.engine_api import InferenceEngine
from realhf.base import logging, pkg_version
import asyncio
import aiohttp
from tensordict import TensorDict
from typing import Dict, Any, Optional, TYPE_CHECKING, Callable
from arealite.api.cli_args import InferenceEngineConfig
from concurrent.futures import ThreadPoolExecutor
from queue import Queue, Empty
import torch.distributed as dist
import traceback
import threading
from realhf.base import name_resolve, names

if TYPE_CHECKING:
    from arealite.api.workflow_api import RolloutWorkflow

logger = logging.getLogger(__name__)

if pkg_version.is_available("sglang"):
    if pkg_version.is_version_greater_or_equal("sglang", "0.4.4"):
        SGLANG_TOKEN_OUTPUT_IDENTIFIER = "output_ids"
    else:
        SGLANG_TOKEN_OUTPUT_IDENTIFIER = "token_ids"

ROLLOUT_POLL_WAIT_TIME = 0.4


class RemoteSGLangEngine(InferenceEngine):

    def __init__(self, config: InferenceEngineConfig):
        self.config = config

        self.rid_to_address = {}
        self.addresses = config.server_addrs
        self.server_idx = 0

        self.input_queue = Queue(maxsize=config.max_concurrent_rollouts)
        self.output_queue = Queue(maxsize=config.max_concurrent_rollouts)
        self.result_cache = []

        self.exiting = threading.Event()
        self.lock = threading.Lock()

        self.rollout_stat = RolloutStat()

    def _get_model_version(self) -> int:
        name = names.model_version(
            self.config.experiment_name,
            self.config.trial_name,
            "actor",
        )
        try:
            return int(name_resolve.get(name))
        except name_resolve.NameEntryNotFoundError:
            return 0

    def initialize(self, addr: str | None, ft_spec: Optional[Dict[str, Any]] = None):
        self.rollout_thread = threading.Thread(target=self._rollout_thread)
        self.rollout_thread.start()

    def _rollout_thread(self):
        """Thread that runs the rollout loop."""
        try:
            # Run the async rollout loop on this dedicated thread's own event loop.
            asyncio.run(self._rollout_thread_async())
        finally:
            self.exiting.set()

    async def _rollout_thread_async(self):
        data = None

        rollout_tasks: Dict[int, asyncio.Task] = {}
        rid = 0

        try:
            while not self.exiting.is_set():
                # Load next data from controller
                if data is None:
                    try:
                        data, workflow = self.input_queue.get_nowait()
                        logger.debug(f"Get data from puller: {data}")
                    except Empty:
                        logger.debug("No data from puller stream.")

                # Check capacity
                if dist.is_initialized():
                    world_size = dist.get_world_size()
                else:
                    world_size = 1

                cannot_rollout_reason = []
                capacity = max(1, self.config.max_concurrent_rollouts // world_size)
                can_rollout = len(rollout_tasks) < capacity
                if not can_rollout:
                    cannot_rollout_reason.append(
                        f"Exceeding capacity: # running tasks {len(rollout_tasks)} >= capacity {capacity}"
                    )

                # Staleness control
                version = self._get_model_version()
                ofp = self.config.max_head_offpolicyness
                with self.lock:
                    sample_cnt = self.rollout_stat.accepted + self.rollout_stat.running
                expected_version = sample_cnt // self.train_batch_size
                not_staled = expected_version <= ofp + version
                can_rollout &= not_staled
                if not not_staled:
                    cannot_rollout_reason.append(
                        f"Staled: expected version ({expected_version}) = "
                        f"global sample cnt ({sample_cnt}) // batch size ({self.train_batch_size}), "
                        f"current latest version {version}, "
                        f"offpolicyness {self.config.max_head_offpolicyness}."
                    )

                if not can_rollout:
                    logger.debug(
                        "Cannot submit new rollouts. "
                        + "\n".join(cannot_rollout_reason)
                    )

                # Create new rollout task
                if can_rollout and data is not None:
                    task = asyncio.create_task(
                        workflow.arun_episode(self, data), name=str(rid)
                    )
                    rollout_tasks[rid] = task

                    with self.lock:
                        self.rollout_stat.submitted += 1
                        self.rollout_stat.running += 1
                        logger.debug(
                            f"Submit rollout rid {rid}. "
                            f"Submit: {self.rollout_stat.submitted}, "
                            f"running: {self.rollout_stat.running}, "
                            f"accepted: {self.rollout_stat.accepted}."
                        )

                    rid += 1
                    data = None

                # Wait for rollout completion
                tasks = list(rollout_tasks.values())
                done = []
                if tasks:
                    done, _ = await asyncio.wait(
                        tasks,
                        timeout=ROLLOUT_POLL_WAIT_TIME,
                        return_when=asyncio.FIRST_COMPLETED,
                    )
                else:
                    await asyncio.sleep(ROLLOUT_POLL_WAIT_TIME)

                # Collect done results
                for task in done:
                    traj = await task
                    traj: TensorDict
                    # Task names are strings; convert back to the integer rid used as key.
                    task_rid = int(task.get_name())
                    rollout_tasks.pop(task_rid)

                    self.output_queue.put(traj)

                    with self.lock:
                        self.rollout_stat.running -= 1
                        logger.debug(
                            f"Finish rollout {task_rid}. "
                            f"Submit: {self.rollout_stat.submitted}, "
                            f"running: {self.rollout_stat.running}, "
                            f"accepted: {self.rollout_stat.accepted}."
                        )
        finally:
            # Cancel remaining tasks
            for task in rollout_tasks.values():
                if not task.done():
                    task.cancel()
                    try:
                        await task
                    except asyncio.CancelledError:
                        pass

    def choose_server(self) -> str:
        if self.config.schedule_policy == "round_robin":
            server = self.addresses[self.server_idx]
            self.server_idx = (self.server_idx + 1) % len(self.addresses)
            return server
        raise NotImplementedError("Only round-robin scheduling is implemented.")

    async def arequest_with_retry(
        self,
        endpoint: str,
        payload: Optional[Dict[str, Any]] = None,
        method: str = "POST",
        max_retries: Optional[int] = None,
        timeout: Optional[float] = None,
        retry_delay: float = 1.0,
        target_addr: Optional[str] = None,
    ) -> aiohttp.ClientResponse:
        timeout = timeout or self.config.request_timeout
        last_exception = None
        max_retries = max_retries or self.config.request_retries

        # Try with retries, picking a (possibly different) server on each attempt.
        for attempt in range(max_retries):
            if target_addr:
                addr = target_addr
            else:
                addr = self.choose_server()
            base_url = f"http://{addr}"
            url = f"{base_url}{endpoint}"

            try:
                async with aiohttp.ClientSession(
                    timeout=aiohttp.ClientTimeout(
                        total=timeout,
                        sock_connect=30,
                        sock_read=timeout,
                    )
                ) as session:
                    if method.upper() == "GET":
                        response = await session.get(url)
                    elif method.upper() == "POST":
                        response = await session.post(url, json=payload)
                    elif method.upper() == "PUT":
                        response = await session.put(url, json=payload)
                    elif method.upper() == "DELETE":
                        response = await session.delete(url)
                    else:
                        raise ValueError(f"Unsupported HTTP method: {method}")

                    response.raise_for_status()
                    # Read the body while the session is still open so callers can
                    # safely call `await response.json()` afterwards.
                    await response.read()
                    return response

            except (
                aiohttp.ClientError,
                aiohttp.ClientResponseError,
                asyncio.TimeoutError,
            ) as e:
                last_exception = e
                if attempt < max_retries - 1:
                    await asyncio.sleep(retry_delay)
                    continue
        raise RuntimeError(
            f"Failed after {max_retries} retries. Last error: {last_exception}"
        )

    async def agenerate(self, req: LLMRequest) -> LLMResponse:
        """Async version of generate using aiohttp."""
        # Prepare request payload
        gconfig = req.gconfig
        stop_token_ids = gconfig.stop_token_ids

        assert gconfig.n_samples == 1
        sample_params = {
            "top_p": gconfig.top_p,
            "top_k": gconfig.top_k,
            "max_new_tokens": gconfig.max_new_tokens,
            "temperature": 0.0 if gconfig.greedy else gconfig.temperature,
            "stop_token_ids": stop_token_ids,
        }

        payload = {
            "rid": req.rid,
            "sampling_params": sample_params,
            "return_logprob": True,
            "stream": False,
        }
        if req.text:
            payload["text"] = req.text
        else:
            payload["input_ids"] = req.input_ids

        # Make request
        start_time = time.perf_counter()
        accumulated_output_tokens = []
        accumulated_output_logprobs = []
        accumulated_versions = []

        # Deal with rollout interruption
        completions = ""
        stop_reason = "length"

        while (
            stop_reason != "stop"
            and len(accumulated_output_tokens) < gconfig.max_new_tokens
        ):
            # Loop until the generation is complete
            response = await self.arequest_with_retry(
                endpoint="/generate",
                payload=payload,
                method="POST",
                max_retries=3,
                timeout=self.config.request_timeout,
            )
            result = await response.json()

            # Parse response
            completions += result["text"]
            meta_info = result["meta_info"]
            output_tokens = [x[1] for x in meta_info["output_token_logprobs"]]
            output_logprobs = [x[0] for x in meta_info["output_token_logprobs"]]

            # Update accumulated outputs
            accumulated_output_tokens.extend(output_tokens)
            accumulated_output_logprobs.extend(output_logprobs)
            # FIXME: Update with actual server versions
            accumulated_versions.extend([-1] * len(output_tokens))

            # Check if generation is complete
            finish_reason = meta_info["finish_reason"]
            stop_reason = finish_reason["type"]

            # Feed only the newly generated chunk back into the prompt so that an
            # interrupted generation resumes where it stopped.
            if "text" in payload:
                payload["text"] += result["text"]
            else:
                payload["input_ids"] = payload["input_ids"] + output_tokens

        latency = time.perf_counter() - start_time

        return LLMResponse(
            completions=completions,
            input_tokens=req.input_ids,
            output_tokens=accumulated_output_tokens,
            output_logprobs=accumulated_output_logprobs,
            output_versions=accumulated_versions,
            stop_reason=stop_reason,
            latency=latency,
            ttft=latency,  # Simplified for non-streaming
        )

    def update_weights(self, meta):
        executor = ThreadPoolExecutor(max_workers=1)
        return executor.submit(self._update_weights, meta)

    def _update_weights(self, meta: WeightUpdateMeta):
        if meta.type == "disk":
            # Update weights from disk
            loop = asyncio.new_event_loop()
            try:
                jobs = [
                    self.aupdate_weights_from_disk(addr, meta.path)
                    for addr in self.addresses
                ]
                loop.run_until_complete(asyncio.gather(*jobs))
            finally:
                loop.close()
        else:
            raise NotImplementedError(f"Unsupported weight update type: {meta.type}")

    async def aupdate_weights_from_disk(self, addr, path: str):
        response = await self.arequest_with_retry(
            endpoint="/update_weights_from_disk",
            payload=dict(model_path=path, allow_interrupt=True),
            method="POST",
            max_retries=3,
            timeout=self.config.request_timeout,
            target_addr=addr,
        )
        res = await response.json()
        assert res["success"]
        if "num_paused_requests" in res:
            logger.info(
                f"{res['num_paused_requests']} requests are interrupted "
                f"during updating weights for server {addr}"
            )

    def submit(self, data: Dict[str, Any], workflow: "RolloutWorkflow") -> None:
        # Enqueue in (data, workflow) order; the rollout thread unpacks it the same way.
        self.input_queue.put((data, workflow))

    def wait(self, count: int, timeout: int, should_accept: Callable) -> TensorDict:
        tik = time.perf_counter()
        accepted = len(self.result_cache)
        while (
            accepted < count
            and not self.exiting.is_set()
            and time.perf_counter() - tik < timeout
        ):
            try:
                result = self.output_queue.get(timeout=ROLLOUT_POLL_WAIT_TIME)
                if should_accept(result):
                    self.result_cache.append(result)
                    accepted += 1
                    with self.lock:
                        self.rollout_stat.accepted += 1
            except Empty:
                time.sleep(ROLLOUT_POLL_WAIT_TIME)
        if self.exiting.is_set():
            raise RuntimeError("Rollout engine is exiting, cannot wait for results.")
        if accepted < count:
            raise TimeoutError(
                f"Timed out waiting for {count} rollouts, only received {accepted}."
            )
        results, self.result_cache = self.result_cache[:count], self.result_cache[count:]
        return TensorDict.cat(results, dim=0)
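
Putting the pieces together, the trainer-side call pattern for this engine looks roughly like the following. A minimal sketch; the config object, workflow, prompt list, and timeout are placeholders, not values fixed by this diff:

engine = RemoteSGLangEngine(config)         # config: InferenceEngineConfig with server_addrs set
engine.initialize(addr=None, ft_spec=None)  # starts the background rollout thread

for item in prompts:                        # prompts: list of dicts the workflow understands
    engine.submit(item, workflow)           # workflow: a RolloutWorkflow with arun_episode()

# Block until a full batch of trajectories has been produced and accepted.
batch = engine.wait(count=len(prompts), timeout=3600, should_accept=lambda td: True)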

ci/build_env_image.sh
@@ -0,0 +1,39 @@
#!/usr/bin/env bash

set -e

GIT_COMMIT_SHA=${GIT_COMMIT_SHA:?"GIT_COMMIT_SHA is not set"}

echo "GIT_COMMIT_SHA: $GIT_COMMIT_SHA"

# If there is already an image named areal-env, skip.
if docker images --format '{{.Repository}}:{{.Tag}}' | grep -q 'areal-env:latest'; then
  echo "Image areal-env already exists, skipping build."
  exit 0
fi

RUN_ID="areal-$GIT_COMMIT_SHA"
cd "/tmp/$RUN_ID"

if docker ps -a --format '{{.Names}}' | grep -q "$RUN_ID"; then
  docker rm -f $RUN_ID
fi

docker run \
  --name $RUN_ID \
  --gpus all \
  --shm-size=8g \
  -v $(pwd):/workspace \
  -w /workspace \
  nvcr.io/nvidia/pytorch:25.01-py3 \
  bash -c "
    python -m pip install --upgrade pip
    pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
    pip config unset global.extra-index-url
    bash examples/env/scripts/setup-pip-deps.sh
    pip uninstall -y transformer-engine
    mv ./sglang /sglang
  " || { docker rm -f $RUN_ID; exit 1; }

docker commit $RUN_ID areal-env:latest
docker rm -f $RUN_ID

ci/clone_repo.sh
@@ -0,0 +1,19 @@
#!/usr/bin/env bash

set -e

GIT_REPO_URL=${GIT_REPO_URL:?"GIT_REPO_URL is not set"}
GIT_COMMIT_SHA=${GIT_COMMIT_SHA:?"GIT_COMMIT_SHA is not set"}

echo "GIT_REPO_URL: $GIT_REPO_URL"
echo "GIT_COMMIT_SHA: $GIT_COMMIT_SHA"

RUN_ID="areal-$GIT_COMMIT_SHA"
rm -rf "/tmp/$RUN_ID"
mkdir -p "/tmp/$RUN_ID"
cd "/tmp/$RUN_ID"

git init
git remote add origin "$GIT_REPO_URL"
git fetch --depth 1 origin "$GIT_COMMIT_SHA"
git checkout FETCH_HEAD

ci/test_arealite.sh
@@ -0,0 +1,28 @@
#!/usr/bin/env bash

set -e

GIT_COMMIT_SHA=${GIT_COMMIT_SHA:?"GIT_COMMIT_SHA is not set"}

echo "GIT_COMMIT_SHA: $GIT_COMMIT_SHA"

RUN_ID="areal-$GIT_COMMIT_SHA"
cd "/tmp/$RUN_ID"

if docker ps -a --format '{{.Names}}' | grep -q "$RUN_ID"; then
  docker rm -f $RUN_ID
fi

docker run \
  --name $RUN_ID \
  --gpus all \
  --shm-size=8g \
  -v $(pwd):/workspace \
  -w /workspace \
  areal-env:latest \
  bash -c "
    mv /sglang ./sglang
    HF_ENDPOINT=https://hf-mirror.com python -m pytest -s arealite/
  " || { docker rm -f $RUN_ID; exit 1; }

docker rm -f $RUN_ID