Merge branch 'lite' of https://github.com/inclusionAI/AReaL into lite

garrett4wade 2025-07-07 17:02:21 +08:00
commit 7ab6755379
9 changed files with 884 additions and 3 deletions

.github/workflows/test-arealite.yml (new file, 50 lines)
@@ -0,0 +1,50 @@
name: Test AReaLite
on:
  push:
    paths:
      - .github/workflows/test-arealite.yml
      - arealite/**
      - ci/**
  workflow_dispatch:
jobs:
  test-arealite:
    runs-on: ubuntu-latest
    concurrency:
      group: test-arealite
    steps:
      - uses: actions/checkout@v4
      - uses: appleboy/ssh-action@v1
        env:
          GIT_REPO_URL: https://github.bibk.top/${{ github.repository }}
          GIT_COMMIT_SHA: ${{ github.sha }}
        with:
          host: ${{ secrets.CI_NODE_ADDR }}
          username: ${{ secrets.CI_NODE_USER }}
          key: ${{ secrets.REMOTE_SSH_KEY }}
          envs: GIT_REPO_URL,GIT_COMMIT_SHA
          script_path: ci/clone_repo.sh
      - uses: appleboy/ssh-action@v1
        env:
          GIT_COMMIT_SHA: ${{ github.sha }}
        with:
          host: ${{ secrets.CI_NODE_ADDR }}
          username: ${{ secrets.CI_NODE_USER }}
          key: ${{ secrets.REMOTE_SSH_KEY }}
          command_timeout: 2h
          envs: GIT_COMMIT_SHA
          script_path: ci/build_env_image.sh
      - uses: appleboy/ssh-action@v1
        env:
          GIT_COMMIT_SHA: ${{ github.sha }}
        with:
          host: ${{ secrets.CI_NODE_ADDR }}
          username: ${{ secrets.CI_NODE_USER }}
          key: ${{ secrets.REMOTE_SSH_KEY }}
          command_timeout: 1h
          envs: GIT_COMMIT_SHA
          script_path: ci/test_arealite.sh

@@ -68,3 +68,23 @@ class GenerationHyperparameters:
args = asdict(self)
args.update(kwargs)
return GenerationHyperparameters(**args)
@dataclass
class InferenceEngineConfig:
# Used by remote inference engines.
server_addrs: List[str] = field(
default_factory=list,
metadata={"help": "List of server addresses for inference."}
)
schedule_policy: str = field(
default="round_robin",
metadata={"help": "Request scheduling policy", "choices": ["round_robin"]},
)
request_timeout: float = field(
default=30.0,
metadata={"help": "Timeout for HTTP requests."}
)
request_retries: int = field(
default=3,
metadata={"help": "Number of retries for failed requests."}
)
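For reference, a minimal sketch of how the new InferenceEngineConfig might be filled in for a pool of remote servers (not part of this commit; the field names come from the dataclass above, the addresses and values are placeholders):

from arealite.api.cli_args import InferenceEngineConfig

config = InferenceEngineConfig(
    server_addrs=["10.0.0.1:30000", "10.0.0.2:30000"],  # placeholder SGLang server addresses
    schedule_policy="round_robin",   # currently the only implemented policy
    request_timeout=120.0,           # generous per-request timeout for long generations
    request_retries=3,
)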

@@ -1,4 +1,7 @@
import abc
from typing import Callable, Dict, List, Any, Optional, TYPE_CHECKING
import torch
from dataclasses import dataclass, field
from concurrent.futures import Future
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional
@@ -13,6 +16,9 @@ from arealite.api.io_struct import (
SaveLoadMeta,
WeightUpdateMeta,
)
from tensordict import TensorDict
if TYPE_CHECKING:
from arealite.api.workflow_api import RolloutWorkflow
@dataclass
@@ -113,14 +119,14 @@ class InferenceEngine(abc.ABC):
"""Asynchronously generate a response for the given request."""
raise NotImplementedError()
def submit(self, data: Dict[str, Any], workflow) -> None:
def submit(self, data: Dict[str, Any], workflow: "RolloutWorkflow") -> None:
"""Asynchronously submit a request to the inference engine. Exits immediately."""
raise NotImplementedError()
def wait(self, count: int, timeout: int) -> Any:
def wait(self, count: int, timeout: int) -> TensorDict:
"""Wait for a specified number of requests to complete, with a timeout."""
raise NotImplementedError()
def rollout(self, data: List[Dict[str, Any]], workflow) -> Any:
def rollout(self, data: List[Dict[str, Any]], workflow: "RolloutWorkflow") -> TensorDict:
"""Submit a batch of requests to the inference engine and wait for the results."""
raise NotImplementedError()
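As a rough illustration of how the re-typed submit/wait/rollout API is meant to be driven from a training loop (a sketch, not code from this commit; engine, workflow, and batch are assumed to exist):

# Streaming style: enqueue rollouts, then block until enough of them finish.
for item in batch:                 # item: Dict[str, Any]
    engine.submit(item, workflow)  # returns immediately
result = engine.wait(count=len(batch), timeout=600)  # TensorDict of completed trajectories

# Equivalent blocking convenience call:
result = engine.rollout(batch, workflow)  # also returns a TensorDict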

@@ -170,3 +170,9 @@ class SaveLoadMeta:
with_optim: bool
tokenizer: PreTrainedTokenizerFast | None
base_model_path: str | None
@dataclass
class RolloutStat:
submitted: int = 0
accepted: int = 0
running: int = 0

arealite/engine/hf.py (new file, 315 lines)
@@ -0,0 +1,315 @@
import asyncio
import functools
import math
import os
from typing import Any, Callable, Dict, List, Optional
import torch
import torch.distributed as dist
import transformers
from transformers import AutoConfig, AutoModelForCausalLM
from arealite.api.cli_args import (
EngineConfig,
MicroBatchSpec,
ParallelismConfig,
TrainingArgs,
)
from arealite.api.engine_api import TrainEngine
from arealite.api.io_struct import FinetuneSpec
from arealite.api.llm_client_api import LLMClient
from arealite.utils import (
get_state_dict_from_repo_id_or_path,
recorder_list,
split_dict_tensor_with_cu_seqlens,
unpack_sequence,
)
from realhf.base import constants
def get_cosine_schedule_with_warmup(
optimizer: torch.optim.Optimizer,
num_warmup_steps: int,
num_training_steps: int,
min_lr_ratio: float = 0.0,
num_cycles: float = 0.5,
last_epoch: int = -1,
):
"""
Create a schedule with a learning rate that decreases following the values of the cosine function from the
initial lr set in the optimizer down to min_lr_ratio times that lr, after a warmup period during which it
increases linearly from 0 to the initial lr.
Args:
optimizer (:class:`~torch.optim.Optimizer`):
The optimizer for which to schedule the learning rate.
num_warmup_steps (:obj:`int`):
The number of steps for the warmup phase.
num_training_steps (:obj:`int`):
The total number of training steps.
min_lr_ratio (:obj:`float`, `optional`, defaults to 0.0):
The minimum lr ratio w.r.t the maximum.
num_cycles (:obj:`float`, `optional`, defaults to 0.5):
The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
following a half-cosine).
last_epoch (:obj:`int`, `optional`, defaults to -1):
The index of the last epoch when resuming training.
Return:
:obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""
assert min_lr_ratio >= 0 and min_lr_ratio <= 1.0
coef = (1 - min_lr_ratio) * 0.5
intercept = (1 + min_lr_ratio) * 0.5
def lr_lambda(current_step):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
progress = float(current_step - num_warmup_steps) / float(
max(1, num_training_steps - num_warmup_steps)
)
x = math.cos(math.pi * float(num_cycles) * 2.0 * progress)
return max(0.0, x * coef + intercept)
return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)
class HFEngine(TrainEngine):
"""Simplified HF engine for transformer models."""
def __init__(self, args: TrainingArgs, engine_config: EngineConfig):
super().__init__(args, engine_config)
self.model = None
self.optimizer = None
self.model_config = None
self.weight_update_group_initialized = False
def init_distributed(self, config: ParallelismConfig, ft_spec: FinetuneSpec):
"""Initialize model in single node."""
if not dist.is_initialized():
dist.init_process_group(backend="nccl")
if dist.get_world_size() > 1:
raise RuntimeError(
"Distributed training is not supported in this engine. "
"Please use FSDP for distributed training."
)
torch.cuda.set_device("cuda:0")
dtype = torch.bfloat16 if self.engine_config.bf16 else torch.float16
self.model_config = AutoConfig.from_pretrained(
pretrained_model_name_or_path=self.engine_config.path,
trust_remote_code=True,
)
with torch.device("cuda"):
# initialize scratch model from config
model = AutoModelForCausalLM.from_config(
self.model_config,
torch_dtype=dtype,
attn_implementation="flash_attention_2",
)
model = model.cuda()
self.model = model
# Set up optimizer
optimizer_config = self.engine_config.optimizer
if optimizer_config is not None:
assert (
optimizer_config.type == "adam"
), "Only AdamW optimizer is supported in this engine."
lr = optimizer_config.lr
weight_decay = optimizer_config.weight_decay
beta1 = optimizer_config.beta1
beta2 = optimizer_config.beta2
eps = optimizer_config.eps
self.optimizer = torch.optim.AdamW(
self.model.parameters(),
lr=lr,
weight_decay=weight_decay,
betas=(beta1, beta2),
eps=eps,
)
total_train_steps = ft_spec.total_train_steps
num_warmup_steps = int(
optimizer_config.warmup_steps_proportion * total_train_steps
)
self.lr_scheduler = get_cosine_schedule_with_warmup(
self.optimizer,
num_warmup_steps,
total_train_steps,
min_lr_ratio=optimizer_config.min_lr_ratio,
)
def train(self, mode: bool = True):
"""Set the module in training mode."""
return self.model.train(mode)
def train_batch(
self,
input_: Dict,
mb_spec: MicroBatchSpec,
loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
loss_weight_fn: Callable[[Dict], float],
) -> Dict:
"""Train on a batch using gradient accumulation."""
assert self.optimizer is not None
assert self.lr_scheduler is not None
self.optimizer.zero_grad()
mb_splits = split_dict_tensor_with_cu_seqlens(input_, mb_spec)
total_loss_weight = torch.tensor(
sum([loss_weight_fn(mb) for mb in mb_splits.mbs]), dtype=torch.float32
)
assert total_loss_weight != 0
for mb_input in mb_splits.mbs:
outputs = self.model(**mb_input)
loss = loss_fn(outputs.logits, mb_input)
loss_scale = loss_weight_fn(mb_input) / total_loss_weight
loss *= loss_scale
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
self.engine_config.optimizer.gradient_clipping,
norm_type=2.0,
error_if_nonfinite=False,
foreach=None,
)
current_lr = self.lr_scheduler.get_last_lr()[0]
# Optimizer step
self.optimizer.step()
return {
"grad_norm": grad_norm,
"lr": current_lr,
}
@torch.no_grad()
def eval_batch(
self,
input_: Dict,
mb_spec: MicroBatchSpec,
loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
loss_weight_fn: Callable[[Dict], float],
) -> torch.Tensor | None:
"""Evaluate on a batch."""
mb_splits = split_dict_tensor_with_cu_seqlens(input_, mb_spec)
total_loss_weight = torch.tensor(
sum([loss_weight_fn(mb) for mb in mb_splits.mbs]), dtype=torch.float32
)
assert total_loss_weight != 0
total_loss = 0.0
total_weight = 0.0
for mb_input in mb_splits.mbs:
outputs = self.model(**mb_input)
loss = loss_fn(outputs.logits, mb_input)
# Simple weight calculation (could be improved)
loss_scale = loss_weight_fn(mb_input) / total_loss_weight
total_loss += loss.item() * loss_scale
total_weight += loss_scale
return torch.tensor(total_loss / total_weight)
@torch.no_grad()
def forward(
self,
input_: Dict,
mb_spec: MicroBatchSpec,
output_seqlens: List[int] | None = None,
post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
aggregate_fn: Callable[[List[Any]], Any] = functools.partial(torch.cat, dim=1),
) -> Any | None:
"""Forward pass with optional post-processing."""
mb_splits = split_dict_tensor_with_cu_seqlens(input_, mb_spec)
if output_seqlens is None:
cu_seqlens = input_["cu_seqlens"]
output_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist()
results = []
for mb_input in mb_splits.mbs:
outputs = self.model(**mb_input)
if post_hook:
result = post_hook(outputs.logits, mb_input)
results.append(result)
else:
results.append(outputs.logits)
res = aggregate_fn(results)
output_seqlens = [output_seqlens[i] for i in mb_splits.forward_indices]
unpacked = unpack_sequence(res, lens=output_seqlens, dim=1)
return aggregate_fn(recorder_list(unpacked, mb_splits.backward_indices))
def step_lr_scheduler(self):
"""Step the learning rate scheduler."""
return self.lr_scheduler.step()
def save_model_to_hf(
self,
path: str,
tokenizer: Optional[transformers.PreTrainedTokenizerFast] = None,
base_model_path: Optional[str] = None,
):
"""Save model in HuggingFace format."""
if self.model is None:
raise RuntimeError("Model not initialized")
os.makedirs(path, exist_ok=True)
state_dict = {k: v.cpu() for k, v in self.model.state_dict().items()}
self.model.save_pretrained(path, state_dict=state_dict)
self.model_config.save_pretrained(path)
if tokenizer is not None:
tokenizer.save_pretrained(path)
def load_model_from_hf(self, path: str):
"""Load model from HuggingFace format."""
full_state = get_state_dict_from_repo_id_or_path(path)
self.model.load_state_dict(
full_state, strict=not self.model_config.tie_word_embeddings
)
if self.model_config.tie_word_embeddings:
self.model.tie_weights()
def save_optimizer_state(self, path: str):
"""Save optimizer state."""
if self.optimizer is None:
raise RuntimeError("Optimizer not initialized")
os.makedirs(path, exist_ok=True)
torch.save(self.optimizer.state_dict(), os.path.join(path, "optimizer.pt"))
def load_optimizer_state(self, path: str):
"""Load optimizer state."""
if self.optimizer is None:
raise RuntimeError("Optimizer not initialized")
optimizer_path = os.path.join(path, "optimizer.pt")
if os.path.exists(optimizer_path):
self.optimizer.load_state_dict(
torch.load(optimizer_path, map_location="cpu")
)
else:
raise RuntimeError(f"Optimizer state file not found: {optimizer_path}")
async def aupdate_weights_to(self, llm_client: LLMClient):
path = constants.get_param_realloc_path(self.args)
self.save_model_to_hf(path)
tasks = [
llm_client.aupdate_weights_from_disk(server_info=server_info, path=path)
for server_info in llm_client.get_healthy_servers()
]
await asyncio.gather(*tasks)
def update_weights_to(self, llm_client: LLMClient):
loop = asyncio.new_event_loop()
try:
loop.run_until_complete(self.aupdate_weights_to(llm_client))
finally:
loop.close()
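A hypothetical usage sketch of the HFEngine training path (not part of this file): the method names come from the class above, while the config objects, the packed batch, and the loss function are stand-ins.

import torch

def sft_loss(logits: torch.Tensor, mb: dict) -> torch.Tensor:
    # Illustrative next-token cross-entropy over a packed sequence.
    labels = mb["input_ids"][:, 1:]
    logprobs = torch.log_softmax(logits[:, :-1].float(), dim=-1)
    return -logprobs.gather(-1, labels.unsqueeze(-1)).mean()

engine = HFEngine(args, engine_config)            # assumed to be built from CLI args elsewhere
engine.init_distributed(parallelism_cfg, ft_spec)
engine.train()
stats = engine.train_batch(
    input_=packed_batch,                          # dict with input_ids, cu_seqlens, ...
    mb_spec=mb_spec,
    loss_fn=sft_loss,
    loss_weight_fn=lambda mb: mb["cu_seqlens"][-1].item(),  # weight each micro-batch by its token count
)
engine.step_lr_scheduler()
print(stats["grad_norm"], stats["lr"])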

@@ -0,0 +1,398 @@
import time
from arealite.api.io_struct import (
LLMRequest,
LLMResponse,
WeightUpdateMeta,
RolloutStat,
)
from arealite.api.engine_api import InferenceEngine
from realhf.base import logging, pkg_version
import asyncio
import aiohttp
from tensordict import TensorDict
from typing import Dict, Any, Optional, TYPE_CHECKING, Callable
from arealite.api.cli_args import InferenceEngineConfig
from concurrent.futures import ThreadPoolExecutor
from queue import Queue, Empty
import torch.distributed as dist
import traceback
import threading
from realhf.base import name_resolve, names
if TYPE_CHECKING:
from arealite.api.workflow_api import RolloutWorkflow
logger = logging.getLogger(__name__)
if pkg_version.is_available("sglang"):
if pkg_version.is_version_greater_or_equal("sglang", "0.4.4"):
SGLANG_TOKEN_OUTPUT_IDENTIFIER = "output_ids"
else:
SGLANG_TOKEN_OUTPUT_IDENTIFIER = "token_ids"
ROLLOUT_POLL_WAIT_TIME = 0.4
class RemoteSGLangEngine(InferenceEngine):
def __init__(self, config: InferenceEngineConfig):
self.config = config
self.rid_to_address = {}
self.addresses = config.server_addrs
self.server_idx = 0
self.input_queue = Queue(maxsize=config.max_concurrent_rollouts)
self.output_queue = Queue(maxsize=config.max_concurrent_rollouts)
self.result_cache = []
self.exiting = threading.Event()
self.lock = threading.Lock()
self.rollout_stat = RolloutStat()
def _get_model_version(self) -> int:
name = names.model_version(
self.config.experiment_name,
self.config.trial_name,
"actor",
)
try:
return int(name_resolve.get(name))
except name_resolve.NameEntryNotFoundError:
return 0
def initialize(self, addr: str | None, ft_spec: Optional[Dict[str, Any]] = None):
self.rollout_thread = threading.Thread(target=self._rollout_thread)
self.rollout_thread.start()
def _rollout_thread(self):
"""Thread that runs the rollout loop."""
try:
asyncio.run(self._rollout_thread_async())  # run the async rollout loop to completion in this thread
finally:
self.exiting.set()
async def _rollout_thread_async(self):
data = None
rollout_tasks: Dict[int, asyncio.Task] = {}
rid = 0
try:
while not self.exiting.is_set():
# Load next data from controller
if data is None:
try:
data, workflow = self.input_queue.get_nowait()
logger.debug(f"Get data from puller: {data}")
except Empty:
logger.debug(f"No data from puller stream.")
# Check capacity
if dist.is_initialized():
world_size = dist.get_world_size()
else:
world_size = 1
cannot_rollout_reason = []
capacity = max(1, self.config.max_concurrent_rollouts // world_size)
can_rollout = len(rollout_tasks) < capacity
if not can_rollout:
cannot_rollout_reason.append(
f"Exceeding capacity: # running tasks {len(rollout_tasks)} >= capacity {capacity}"
)
# Staleness control
version = self._get_model_version()
ofp = self.config.max_head_offpolicyness
with self.lock:
sample_cnt = self.rollout_stat.accepted + self.rollout_stat.running
expected_version = sample_cnt // self.train_batch_size
not_staled = expected_version <= ofp + version
can_rollout &= not_staled
if not not_staled:
cannot_rollout_reason.append(
f"Staled: expected version ({expected_version}) = "
f"global sample cnt ({sample_cnt}) // batch size ({self.train_batch_size}), "
f"current latest version {version}, "
f"offpolicyness {self.config.max_head_offpolicyness}."
)
if not can_rollout:
logger.debug(
f"Cannot submit new rollouts. "
+ "\n".join(cannot_rollout_reason)
)
# Create new rollout task
if can_rollout and data is not None:
task = asyncio.create_task(
workflow.arun_episode(self, data), name=str(rid)
)
rollout_tasks[rid] = task
with self.lock:
self.rollout_stat.submitted += 1
self.rollout_stat.running += 1
logger.debug(
f"Submit rollout rid {rid}. "
f"Submit: {self.rollout_stat.submitted}, "
f"running: {self.rollout_stat.running}, "
f"accepted: {self.rollout_stat.accepted}."
)
rid += 1
data = None
# Wait for rollout completion
tasks = list(rollout_tasks.values())
done = []
if tasks:
done, _ = await asyncio.wait(
tasks,
timeout=ROLLOUT_POLL_WAIT_TIME,
return_when=asyncio.FIRST_COMPLETED,
)
else:
await asyncio.sleep(ROLLOUT_POLL_WAIT_TIME)
# Collect done results
for task in done:
traj: TensorDict = await task
# Task names were created with str(rid), so convert back to the integer key.
task_rid = int(task.get_name())
rollout_tasks.pop(task_rid)
self.output_queue.put(traj)
with self.lock:
self.rollout_stat.running -= 1
logger.debug(
f"Finish rollout {task_rid}. "
f"Submit: {self.rollout_stat.submitted}, "
f"running: {self.rollout_stat.running}, "
f"accepted: {self.rollout_stat.accepted}."
)
finally:
# Cancel remaining tasks
for task in rollout_tasks.values():
if not task.done():
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
def choose_server(self) -> str:
if self.config.schedule_policy == "round_robin":
server = self.addresses[self.server_idx]
self.server_idx = (self.server_idx + 1) % len(self.addresses)
return server
raise NotImplementedError("Only round-robin scheduling is implemented.")
async def arequest_with_retry(
self,
endpoint: str,
payload: Optional[Dict[str, Any]] = None,
method: str = "POST",
max_retries: Optional[int] = None,
timeout: Optional[float] = None,
retry_delay: float = 1.0,
target_addr: Optional[str] = None,
) -> aiohttp.ClientResponse:
timeout = timeout or self.config.request_timeout
last_exception = None
max_retries = max_retries or self.config.request_retries
# Retry up to max_retries times, choosing a (possibly different) server on each attempt.
for attempt in range(max_retries):
if target_addr:
addr = target_addr
else:
addr = self.choose_server()
base_url = f"http://{addr}"
url = f"{base_url}{endpoint}"
try:
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(
total=timeout,
sock_connect=30,
sock_read=timeout,
)
) as session:
if method.upper() == "GET":
response = await session.get(url)
elif method.upper() == "POST":
response = await session.post(url, json=payload)
elif method.upper() == "PUT":
response = await session.put(url, json=payload)
elif method.upper() == "DELETE":
response = await session.delete(url)
else:
raise ValueError(f"Unsupported HTTP method: {method}")
response.raise_for_status()
# Buffer the payload before the session closes so callers can still
# call `await response.json()` on the returned object.
await response.read()
return response
except (
aiohttp.ClientError,
aiohttp.ClientResponseError,
asyncio.TimeoutError,
) as e:
last_exception = e
if attempt < max_retries - 1:
await asyncio.sleep(retry_delay)
continue
raise RuntimeError(
f"Failed after {max_retries} retries. Last error: {last_exception}"
)
async def agenerate(self, req: LLMRequest) -> LLMResponse:
"""Async version of generate using aiohttp."""
# Prepare request payload
gconfig = req.gconfig
stop_token_ids = gconfig.stop_token_ids
assert gconfig.n_samples == 1
sample_params = {
"top_p": gconfig.top_p,
"top_k": gconfig.top_k,
"max_new_tokens": gconfig.max_new_tokens,
"temperature": 0.0 if gconfig.greedy else gconfig.temperature,
"stop_token_ids": stop_token_ids,
}
payload = {
"rid": req.rid,
"sampling_params": sample_params,
"return_logprob": True,
"stream": False,
}
# Send either raw text or pre-tokenized input ids, never a possibly-None "text" field.
if req.text:
payload["text"] = req.text
else:
payload["input_ids"] = req.input_ids
# Make request
start_time = time.perf_counter()
accumulated_output_tokens = []
accumulated_output_logprobs = []
accumulated_versions = []
# Deal with rollout interruption
completions = ""
stop_reason = "length"
while (
stop_reason != "stop"
and len(accumulated_output_tokens) < gconfig.max_new_tokens
):
# loop until the generation is complete
response = await self.arequest_with_retry(
endpoint="/generate",
payload=payload,
method="POST",
max_retries=3,
timeout=self.config.request_timeout,
)
result = await response.json()
# Parse response
completions += result["text"]
meta_info = result["meta_info"]
output_tokens = [x[1] for x in meta_info["output_token_logprobs"]]
output_logprobs = [x[0] for x in meta_info["output_token_logprobs"]]
# Update accumulated outputs
accumulated_output_tokens.extend(output_tokens)
accumulated_output_logprobs.extend(output_logprobs)
# FIXME: Update with actual server versions
accumulated_versions.extend([-1] * len(output_tokens))
# Check if generation is complete
finish_reason = meta_info["finish_reason"]
stop_reason = finish_reason["type"]
payload["text"] += completions
latency = time.perf_counter() - start_time
return LLMResponse(
completions=completions,
input_tokens=req.input_ids,
output_tokens=accumulated_output_tokens,
output_logprobs=accumulated_output_logprobs,
output_versions=accumulated_versions,
stop_reason=stop_reason,
latency=latency,
ttft=latency, # Simplified for non-streaming
)
def update_weights(self, meta):
executor = ThreadPoolExecutor(max_workers=1)
return executor.submit(self._update_weights, meta)
def _update_weights(self, meta: WeightUpdateMeta):
if meta.type == "disk":
# Update weights from disk
try:
jobs = [
self.aupdate_weights_from_disk(addr, meta.path)
for addr in self.addresses
]
loop = asyncio.new_event_loop()
loop.run_until_complete(asyncio.gather(*jobs))
finally:
loop.close()
else:
raise NotImplementedError(f"Unsupported weight update type: {meta.type}")
async def aupdate_weights_from_disk(self, addr, path: str):
response = await self.arequest_with_retry(
endpoint="/update_weights_from_disk",
payload=dict(model_path=path, allow_interrupt=True),
method="POST",
max_retries=3,
timeout=self.config.request_timeout,
target_addr=addr,
)
res = await response.json()
assert res["success"]
if "num_paused_requests" in res:
logger.info(
f"{res['num_paused_requests']} requests are interrupted "
f"during updating weights for server {addr}"
)
def submit(self, data: Dict[str, Any], workflow: "RolloutWorkflow") -> None:
self.input_queue.put((data, workflow))  # match the (data, workflow) unpacking in the rollout thread
def wait(self, count: int, timeout: int, should_accept: Callable) -> TensorDict:
tik = time.perf_counter()
accepted = len(self.result_cache)
while accepted < count and not self.exiting.is_set() and time.perf_counter() - tik < timeout:
try:
result = self.output_queue.get(timeout=ROLLOUT_POLL_WAIT_TIME)
if should_accept(result):
self.result_cache.append(result)
accepted += 1
with self.lock:
self.rollout_stat.accepted += 1
except Empty:
time.sleep(ROLLOUT_POLL_WAIT_TIME)
if self.exiting.is_set():
raise RuntimeError("Rollout engine is exiting, cannot wait for results.")
if accepted < count:
raise TimeoutError(
f"Timed out waiting for {count} rollouts, "
f"only received {accepted}."
)
results, self.result_cache = self.result_cache[:count], self.result_cache[count:]
return TensorDict.cat(results, dim=0)
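To show how the two new engines are meant to interact for weight synchronization, here is a sketch of the disk-based update path (assumptions: a shared checkpoint path and a WeightUpdateMeta constructed with the type and path fields used by _update_weights above; everything else is illustrative and not code from this commit):

from arealite.api.io_struct import WeightUpdateMeta

# After a training step, publish new weights to every SGLang server.
stats = train_engine.train_batch(batch, mb_spec, loss_fn, loss_weight_fn)
train_engine.save_model_to_hf("/shared/ckpt/step_10")        # HF-format checkpoint on shared storage

meta = WeightUpdateMeta(type="disk", path="/shared/ckpt/step_10")  # constructor fields assumed
future = rollout_engine.update_weights(meta)   # dispatches /update_weights_from_disk to each server
future.result()                                # block until every server has reloaded the weights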

ci/build_env_image.sh (new file, 39 lines)
@@ -0,0 +1,39 @@
#!/usr/bin/env bash
set -e
GIT_COMMIT_SHA=${GIT_COMMIT_SHA:?"GIT_COMMIT_SHA is not set"}
echo "GIT_COMMIT_SHA: $GIT_COMMIT_SHA"
# If there is already an image named areal-env, skip.
if docker images --format '{{.Repository}}:{{.Tag}}' | grep -q 'areal-env:latest'; then
echo "Image areal-env already exists, skipping build."
exit 0
fi
RUN_ID="areal-$GIT_COMMIT_SHA"
cd "/tmp/$RUN_ID"
if docker ps -a --format '{{.Names}}' | grep -q "$RUN_ID"; then
docker rm -f $RUN_ID
fi
docker run \
--name $RUN_ID \
--gpus all \
--shm-size=8g \
-v $(pwd):/workspace \
-w /workspace \
nvcr.io/nvidia/pytorch:25.01-py3 \
bash -c "
python -m pip install --upgrade pip
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
pip config unset global.extra-index-url
bash examples/env/scripts/setup-pip-deps.sh
pip uninstall -y transformer-engine
mv ./sglang /sglang
" || { docker rm -f $RUN_ID; exit 1; }
docker commit $RUN_ID areal-env:latest
docker rm -f $RUN_ID

ci/clone_repo.sh (new file, 19 lines)
@@ -0,0 +1,19 @@
#!/usr/bin/env bash
set -e
GIT_REPO_URL=${GIT_REPO_URL:?"GIT_REPO_URL is not set"}
GIT_COMMIT_SHA=${GIT_COMMIT_SHA:?"GIT_COMMIT_SHA is not set"}
echo "GIT_REPO_URL: $GIT_REPO_URL"
echo "GIT_COMMIT_SHA: $GIT_COMMIT_SHA"
RUN_ID="areal-$GIT_COMMIT_SHA"
rm -rf "/tmp/$RUN_ID"
mkdir -p "/tmp/$RUN_ID"
cd "/tmp/$RUN_ID"
git init
git remote add origin "$GIT_REPO_URL"
git fetch --depth 1 origin "$GIT_COMMIT_SHA"
git checkout FETCH_HEAD

ci/test_arealite.sh (new file, 28 lines)
@@ -0,0 +1,28 @@
#!/usr/bin/env bash
set -e
GIT_COMMIT_SHA=${GIT_COMMIT_SHA:?"GIT_COMMIT_SHA is not set"}
echo "GIT_COMMIT_SHA: $GIT_COMMIT_SHA"
RUN_ID="areal-$GIT_COMMIT_SHA"
cd "/tmp/$RUN_ID"
if docker ps -a --format '{{.Names}}' | grep -q "$RUN_ID"; then
docker rm -f $RUN_ID
fi
docker run \
--name $RUN_ID \
--gpus all \
--shm-size=8g \
-v $(pwd):/workspace \
-w /workspace \
areal-env:latest \
bash -c "
mv /sglang ./sglang
HF_ENDPOINT=https://hf-mirror.com python -m pytest -s arealite/
" || { docker rm -f $RUN_ID; exit 1; }
docker rm -f $RUN_ID