This commit is contained in:
博惟 2025-07-07 13:47:41 +08:00
parent 5b7c83b5d9
commit 1dfe91c470
7 changed files with 489 additions and 0 deletions

68
arealite/api/cli_args.py Normal file
View File

@@ -0,0 +1,68 @@
from dataclasses import asdict, dataclass, field
from typing import List


@dataclass
class MicroBatchSpec:
    """Specification for splitting micro-batches during training."""

    n_mbs: int = field(
        default=1,
        metadata={
            "help": "Number of micro-batches. When max_tokens_per_mb is set, "
            "this acts as the minimum number of micro-batches instead.",
        },
    )
    max_tokens_per_mb: int = field(
        default=int(1e12),
        metadata={
            "help": "Maximum number of tokens per micro-batch. When set, "
            "n_mbs becomes the minimum number of micro-batches.",
        },
    )

    @classmethod
    def new(cls, mb_spec: "MicroBatchSpec", **kwargs):
        """Create a new spec with updated fields while maintaining OmegaConf compatibility."""
        fields = dict(
            n_mbs=mb_spec.n_mbs,
            max_tokens_per_mb=mb_spec.max_tokens_per_mb,
        )
        fields.update(kwargs)
        return cls(**fields)


@dataclass
class GenerationHyperparameters:
    """Controls text generation behavior for RL training."""

    n_samples: int = field(
        default=1, metadata={"help": "Number of sequences to generate per prompt."}
    )
    max_new_tokens: int = field(
        default=16384, metadata={"help": "Maximum number of tokens to generate."}
    )
    min_new_tokens: int = field(
        default=0, metadata={"help": "Minimum number of tokens to generate."}
    )
    greedy: bool = field(
        default=False,
        metadata={"help": "Whether to use greedy decoding (maximum probability)."},
    )
    top_p: float = field(
        default=1.0,
        metadata={"help": "Nucleus sampling probability threshold in (0.0, 1.0]."},
    )
    top_k: int = field(
        default=int(1e8),
        metadata={"help": "Number of highest-probability tokens to consider."},
    )
    temperature: float = field(
        default=1.0,
        metadata={"help": "Sampling temperature. Higher values increase diversity."},
    )
    stop_token_ids: List[int] = field(
        default_factory=list,
        metadata={"help": "Stop generation when encountering these token IDs."},
    )

    def new(self, **kwargs):
        """Create a copy with updated fields."""
        args = asdict(self)
        args.update(kwargs)
        return GenerationHyperparameters(**args)
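
Both dataclasses share a copy-with-overrides pattern via new(), so callers can derive variants without mutating a shared config. A minimal usage sketch (the concrete values are illustrative, not defaults recommended by this commit):

# Sampling config for rollouts.
gconfig = GenerationHyperparameters(n_samples=4, temperature=0.7)

# Derive a greedy variant for evaluation without mutating the original.
eval_gconfig = gconfig.new(greedy=True, n_samples=1)

# Token-budgeted micro-batching: n_mbs acts as a floor once
# max_tokens_per_mb is set.
mb_spec = MicroBatchSpec.new(MicroBatchSpec(), n_mbs=2, max_tokens_per_mb=8192)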

124
arealite/api/engine_api.py Normal file
View File

@@ -0,0 +1,124 @@
import abc
from concurrent.futures import Future
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional

import torch

from arealite.api.cli_args import MicroBatchSpec
from arealite.api.io_struct import (
    FinetuneSpec,
    LLMRequest,
    LLMResponse,
    SaveLoadMeta,
    WeightUpdateMeta,
)


@dataclass
class Scheduling:
    cpu: int
    gpu: int
    mem: int
    nodelist: Optional[str] = None
    exclude: Optional[str] = None
    partition: Optional[str] = None
    container_image: Optional[str] = None
    env_vars: Dict[str, str] = field(default_factory=dict)
    # Time formats follow https://slurm.schedmd.com/sbatch.html
    time_limit: Optional[str] = None  # see the "--time" option for format
    begin: Optional[str] = None  # see the "--begin" option for format
    deadline: Optional[str] = None  # see the "--deadline" option for format


class TrainEngine(abc.ABC):
    def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None):
        """Initialize the distributed training environment and load models."""
        raise NotImplementedError()

    def get_scheduling_config(self) -> Scheduling:
        """Get the scheduling configuration for the engine, e.g., container image and CPU/GPU/memory size."""
        raise NotImplementedError()

    def destroy(self):
        """Destroy the engine and release GPU memory."""
        pass

    def upload_weights(self, meta: WeightUpdateMeta):
        """Upload weights to the inference engine."""
        raise NotImplementedError()

    def save(self, meta: SaveLoadMeta):
        """Save model weights (and optimizer states) for later use."""
        raise NotImplementedError()

    def load(self, meta: SaveLoadMeta):
        """Load model weights and optimizer states from a file."""
        raise NotImplementedError()

    def step_lr_scheduler(self):
        """Step the learning rate scheduler.

        Since PPO uses minibatch updates, this method needs to be called only
        once after several train_batch calls. It is separated from train_batch
        to allow for more flexible scheduling.
        """
        raise NotImplementedError()

    def train_batch(
        self,
        input_: Dict,
        mb_spec: MicroBatchSpec,
        loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
        loss_weight_fn: Callable[[Dict], float],
    ) -> Dict[str, float]:
        """Update the model with a batch of data and a loss function."""
        raise NotImplementedError()

    @torch.no_grad()
    def eval_batch(
        self,
        input_: Dict,
        mb_spec: MicroBatchSpec,
        loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
        loss_weight_fn: Callable[[Dict], float],
    ) -> torch.Tensor | None:
        """Evaluate the model using the forward pass and loss function."""
        raise NotImplementedError()

    @torch.no_grad()
    def forward(
        self,
        input_: Dict,
        mb_spec: MicroBatchSpec,
        output_seqlens: List[List[int]] | None = None,
        post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
        aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
    ) -> Any | None:
        """Run the forward pass or inference on the model. Note that it is gradient-free."""
        raise NotImplementedError()
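
This interface delegates loss computation to the caller: loss_fn maps model outputs and the input dict to a scalar loss, and loss_weight_fn assigns each micro-batch a weight so results can be aggregated after micro-batch splitting. A caller-side sketch under those assumptions; engine, minibatches, and sft_loss are hypothetical names, and the token-count weighting is an assumed convention, not something this commit specifies:

def sft_loss(logits: torch.Tensor, input_: Dict) -> torch.Tensor:
    # Illustrative next-token cross-entropy, masked to completion tokens.
    labels = input_["input_ids"][:, 1:]
    logprobs = torch.log_softmax(logits[:, :-1], dim=-1)
    nll = -logprobs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
    mask = 1.0 - input_["prompt_mask"][:, 1:].float()
    return (nll * mask).sum() / mask.sum()

mb_spec = MicroBatchSpec(n_mbs=1, max_tokens_per_mb=16384)
for minibatch in minibatches:
    stats = engine.train_batch(
        minibatch,
        mb_spec=mb_spec,
        loss_fn=sft_loss,
        # Weight each micro-batch by its token count (an assumption) so that
        # splitting does not change the effective loss.
        loss_weight_fn=lambda mb: mb["prompt_mask"].numel(),
    )
engine.step_lr_scheduler()  # once per set of minibatch updates, per the docstring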
class InferenceEngine(abc.ABC):
    def initialize(self, addr: str | None, ft_spec):
        """Initialize the distributed inference environment and load models."""
        raise NotImplementedError()

    def update_weights(self, meta: WeightUpdateMeta) -> Future:
        """Update weights in the inference engine."""
        raise NotImplementedError()

    async def agenerate(self, req: LLMRequest) -> LLMResponse:
        """Asynchronously generate a response for the given request."""
        raise NotImplementedError()

    def submit(self, data: Dict[str, Any], workflow) -> None:
        """Asynchronously submit a request to the inference engine. Returns immediately."""
        raise NotImplementedError()

    def wait(self, count: int, timeout: int) -> Any:
        """Wait for a specified number of requests to complete, with a timeout."""
        raise NotImplementedError()

    def rollout(self, data: List[Dict[str, Any]], workflow) -> Any:
        """Submit a batch of requests to the inference engine and wait for the results."""
        raise NotImplementedError()
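
The submit/wait pair decouples request submission from result collection, which is what enables asynchronous rollout; rollout is the blocking convenience wrapper over the same path. A sketch of the intended relationship, assuming a concrete subclass, not a prescribed implementation:

for item in data:
    inference_engine.submit(item, workflow)  # returns immediately
results = inference_engine.wait(count=len(data), timeout=3600)
# ...is roughly equivalent to the one-shot call:
results = inference_engine.rollout(data, workflow)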

30
arealite/api/env_api.py Normal file
View File

@@ -0,0 +1,30 @@
import abc
from typing import Any, Callable, Dict, List


class Environment(abc.ABC):
    async def ainitialize(self):
        """Perform the initialization logic for the environment asynchronously.

        For stateful environments, this is where resources are created and
        prepared (e.g., launching a browser).
        """
        pass

    def list_tools(self) -> List[Dict[str, Any]]:
        """List all available tools in the environment."""
        return []

    async def aexecute(self, tool_name: str, tool_args: Dict[str, Any]) -> Any:
        """Execute a tool in the environment asynchronously."""
        raise NotImplementedError()

    async def aclose(self):
        """Destroy the environment asynchronously, releasing all held resources.

        This method is critical for stateful environments (e.g., a browser session).
        """
        pass
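
A minimal stateless subclass to illustrate the intended contract; CalculatorEnv and its tool schema are hypothetical and not part of this commit:

class CalculatorEnv(Environment):
    def list_tools(self):
        # Tool schemas are free-form dicts here; this particular shape is an
        # assumption, not a requirement of the interface.
        return [
            {
                "name": "add",
                "description": "Add two numbers.",
                "parameters": {"a": "number", "b": "number"},
            }
        ]

    async def aexecute(self, tool_name, tool_args):
        if tool_name == "add":
            return tool_args["a"] + tool_args["b"]
        raise ValueError(f"Unknown tool: {tool_name}")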

172
arealite/api/io_struct.py Normal file
View File

@@ -0,0 +1,172 @@
# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0

import enum
import itertools
import re
import uuid
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional

from transformers import PreTrainedTokenizerFast

from arealite.api.cli_args import GenerationHyperparameters


@dataclass
class LLMRequest:
    rid: str = field(default_factory=lambda: str(uuid.uuid4()))
    text: Optional[str] = None
    input_ids: List[int] = field(default_factory=list)
    gconfig: GenerationHyperparameters = field(
        default_factory=GenerationHyperparameters
    )
    metadata: Dict[str, Any] = field(default_factory=dict)
    model_id: Optional[str] = None


@dataclass
class LLMResponse:
    # outputs
    completions: str
    input_tokens: List[int] = field(default_factory=list)
    output_tokens: List[int] = field(default_factory=list)
    output_logprobs: List[float] = field(default_factory=list)
    output_versions: List[int] = field(default_factory=list)
    stop_reason: Literal["length", "stop", "interrupt"] = "stop"
    # statistics
    latency: float = float("inf")
    ttft: float = float("inf")  # time to first token
    itl: List[float] = field(default_factory=list)  # inter-token latencies

    @property
    def input_len(self) -> int:
        return len(self.input_tokens)

    @property
    def output_len(self) -> int:
        return len(self.output_tokens)


@dataclass
class FinetuneSpec:
    total_train_epochs: int
    dataset_size: int
    train_batch_size: int

    @property
    def total_train_steps(self):
        # assuming drop_last
        return self.total_train_epochs * (self.dataset_size // self.train_batch_size)


class AllocationType(enum.Enum):
    DECOUPLED_vLLM = 1
    DECOUPLED_SGLANG = 2


@dataclass
class AllocationMode:
    type_: AllocationType
    parallel_strat: None | Dict[str, Dict[str, int]]

    @property
    def gen_tp_size(self) -> int:
        return self.parallel_strat["gen"]["t"]

    @property
    def gen_pp_size(self) -> int:
        return self.parallel_strat["gen"]["p"]

    @property
    def gen_dp_size(self) -> int:
        return self.parallel_strat["gen"]["d"]

    @property
    def gen_world_size(self) -> int:
        return self.gen_dp_size * self.gen_pp_size * self.gen_tp_size

    @property
    def train_tp_size(self) -> int:
        return self.parallel_strat["*"]["t"]

    @property
    def train_pp_size(self) -> int:
        return self.parallel_strat["*"]["p"]

    @property
    def train_dp_size(self) -> int:
        return self.parallel_strat["*"]["d"]

    @property
    def train_world_size(self) -> int:
        return self.train_dp_size * self.train_pp_size * self.train_tp_size

    @classmethod
    def from_str(cls, allocation_mode: str):
        alloc_decoupled = AllocationMode.extract_decoupled_alloc(allocation_mode)
        if "vllm" in allocation_mode:
            return cls(AllocationType.DECOUPLED_vLLM, alloc_decoupled)
        elif "sglang" in allocation_mode:
            return cls(AllocationType.DECOUPLED_SGLANG, alloc_decoupled)
        raise NotImplementedError(f"Failed to parse allocation: {allocation_mode}")

    @staticmethod
    def extract_3d_alloc(allocation_mode: str) -> Dict | None:
        for x, y, z in itertools.permutations(["d", "t", "p"]):
            pattern = rf"{x}(\d+){y}(\d+){z}(\d+)"
            m = re.match(pattern, allocation_mode)
            if not m:
                continue
            a, b, c = map(int, m.groups())
            # to be consistent with the key-value pattern
            return {
                "*": {
                    x: a,
                    y: b,
                    z: c,
                }
            }

    @staticmethod
    def extract_decoupled_alloc(allocation_mode: str) -> Dict | None:
        pattern = re.compile(
            r"(?:(?:vllm|sglang)\.(.+?)\+(.+))|(?:(.+?)\+(?:vllm|sglang)\.(.+))"
        )
        m = pattern.match(allocation_mode)
        if not m:
            return
        if m.group(1):
            gen_alloc = m.group(1)
            other_alloc = m.group(2)
        else:
            gen_alloc = m.group(4)
            other_alloc = m.group(3)
        gen_alloc = AllocationMode.extract_3d_alloc(gen_alloc)
        if not gen_alloc:
            return
        other_alloc = AllocationMode.extract_3d_alloc(
            other_alloc
        ) or AllocationMode.extract_key_value_alloc(other_alloc)
        if not other_alloc:
            return
        other_alloc.update({"gen": gen_alloc["*"]})
        return other_alloc


@dataclass
class WeightUpdateMeta:
    type: str
    path: str | None
    alloc_mode: AllocationMode | None
    comm_backend: str | None


@dataclass
class SaveLoadMeta:
    path: str
    weight_format: str
    with_optim: bool
    tokenizer: PreTrainedTokenizerFast | None
    base_model_path: str | None
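
The allocation-mode grammar accepted by AllocationMode.from_str is easiest to see by example. In a decoupled allocation string, the segment prefixed with the inference backend describes generation parallelism (d = data, p = pipeline, t = tensor, per the properties above) and the other segment describes training. A sketch with an illustrative string:

# "sglang.d2p1t2+d2p2t2": 2*1*2 = 4 GPUs for SGLang generation,
# 2*2*2 = 8 GPUs for training.
alloc = AllocationMode.from_str("sglang.d2p1t2+d2p2t2")
assert alloc.type_ == AllocationType.DECOUPLED_SGLANG
assert alloc.gen_world_size == 4    # parallel_strat["gen"] == {"d": 2, "p": 1, "t": 2}
assert alloc.train_world_size == 8  # parallel_strat["*"] == {"d": 2, "p": 2, "t": 2}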

View File

@@ -0,0 +1,24 @@
from typing import List


def reward_fn(
    prompt: str,
    completions: str,
    prompt_ids: List[int],
    completion_ids: List[int],
    **kwargs,
):
    """A placeholder for the reward function used in the RLVR pipeline.

    In general, there is no restriction on the signature or implementation of
    the reward function in customized rollout workflows, but following this
    signature makes it directly usable in our predefined rollout workflows.

    :param prompt: The string representing the task to be completed.
    :param completions: The string representing the trajectory generated by the model.
    :param prompt_ids: The token IDs of the prompt.
    :param completion_ids: The token IDs of the trajectory generated by the model.
    :param kwargs: Other attributes of the data in the dataset, such as solutions,
        input_outputs, etc. Any other attributes in the dataset will be passed as
        keyword arguments to this function.
    :rtype: float
    """
    pass
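
A minimal concrete implementation following this signature; the answer keyword and the last-line extraction rule are illustrative assumptions about the dataset, not part of this commit:

def exact_match_reward_fn(
    prompt: str,
    completions: str,
    prompt_ids: List[int],
    completion_ids: List[int],
    answer: str | None = None,  # hypothetical dataset field, arrives via **kwargs
    **kwargs,
) -> float:
    # Reward 1.0 if the completion's final line matches the reference answer.
    lines = completions.strip().splitlines()
    predicted = lines[-1].strip() if lines else ""
    return float(answer is not None and predicted == str(answer).strip())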

17
arealite/api/workflow_api.py Normal file
View File

@@ -0,0 +1,17 @@
from typing import TYPE_CHECKING, Any, Dict

from tensordict import TensorDict

if TYPE_CHECKING:
    from arealite.api.engine_api import InferenceEngine


class RolloutWorkflow:
    async def arun_episode(
        self, engine: "InferenceEngine", data: Dict[str, Any]
    ) -> TensorDict:
        """Run a single episode of the workflow.

        See concrete example implementations under the `arealite/workflow` directory.
        """
        raise NotImplementedError()

54
arealite/workflow/rlvr.py Normal file
View File

@@ -0,0 +1,54 @@
import uuid

import torch
from tensordict import TensorDict
from transformers import PreTrainedTokenizerFast

from arealite.api.cli_args import GenerationHyperparameters
from arealite.api.io_struct import LLMRequest
from arealite.api.workflow_api import RolloutWorkflow


class RLVRWorkflow(RolloutWorkflow):
    def __init__(
        self,
        reward_fn,
        gconfig: GenerationHyperparameters,
        tokenizer: PreTrainedTokenizerFast,
    ):
        self.reward_fn = reward_fn
        self.gconfig = gconfig
        self.tokenizer = tokenizer

    async def arun_episode(self, engine, data):
        # Render the chat messages into a prompt string for generation.
        text = self.tokenizer.apply_chat_template(
            data["messages"], tokenize=False, add_generation_prompt=True
        )
        req = LLMRequest(
            rid=uuid.uuid4().hex,
            text=text,
            gconfig=self.gconfig,
        )
        resp = await engine.agenerate(req)

        # Assemble the full sequence; prompt positions get placeholder
        # logprobs/versions and are marked by prompt_mask.
        seq = resp.input_tokens + resp.output_tokens
        logprobs = [0.0] * resp.input_len + resp.output_logprobs
        prompt_mask = [1] * resp.input_len + [0] * resp.output_len
        versions = [-1] * resp.input_len + resp.output_versions

        reward = self.reward_fn(
            prompt=req.text,
            completions=resp.completions,
            prompt_ids=resp.input_tokens,
            completion_ids=resp.output_tokens,
            **data,
        )
        res = dict(
            # unsqueeze to add an additional batch dimension
            input_ids=torch.tensor(seq).unsqueeze(0),
            prompt_mask=torch.tensor(prompt_mask).unsqueeze(0),
            logprobs=torch.tensor(logprobs).unsqueeze(0),
            versions=torch.tensor(versions).unsqueeze(0),
            # reward
            rewards=torch.tensor([reward]),
        )
        return TensorDict(res, batch_size=[1])
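
End to end, this workflow is what an InferenceEngine runs once per dataset item via submit/rollout. A wiring sketch, assuming a concrete engine and tokenizer and a reward function matching the placeholder signature, such as the illustrative exact_match_reward_fn above (engine, tokenizer, and the sample data are hypothetical):

workflow = RLVRWorkflow(
    reward_fn=exact_match_reward_fn,  # any callable matching the reward_fn signature
    gconfig=GenerationHyperparameters(n_samples=1),
    tokenizer=tokenizer,
)
# Each item needs a "messages" chat list; extra keys (e.g. "answer") are
# forwarded to reward_fn through **data.
batch = engine.rollout(
    [{"messages": [{"role": "user", "content": "1+1=?"}], "answer": "2"}],
    workflow,
)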