mirror of https://github.com/inclusionAI/AReaL
add api
This commit is contained in:
parent
5b7c83b5d9
commit
1dfe91c470
|
@ -0,0 +1,68 @@
|
|||
from dataclasses import dataclass, field, asdict
|
||||
from typing import List
|
||||
|
||||
@dataclass
|
||||
class MicroBatchSpec:
|
||||
"""Specification for splitting micro-batches during training."""
|
||||
|
||||
n_mbs: int = field(
|
||||
default=1,
|
||||
metadata={
|
||||
"help": "Number of micro-batches (or minimum number if max_tokens_per_mb is set). Used when max_tokens_per_mb is None or as minimum count",
|
||||
},
|
||||
)
|
||||
max_tokens_per_mb: int = field(
|
||||
default=int(1e12),
|
||||
metadata={
|
||||
"help": "Maximum tokens per micro-batch. When set, n_mbs becomes the minimum number of micro-batches",
|
||||
},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def new(cls, mb_spec: "MicroBatchSpec", **kwargs):
|
||||
"""Create new spec with updated fields while maintaining Omegaconf compatibility."""
|
||||
fields = dict(
|
||||
n_mbs=mb_spec.n_mbs,
|
||||
max_tokens_per_mb=mb_spec.max_tokens_per_mb,
|
||||
)
|
||||
fields.update(kwargs)
|
||||
return cls(**fields)
|
||||
|
||||
@dataclass
|
||||
class GenerationHyperparameters:
|
||||
"""Controls text generation behavior for RL training."""
|
||||
|
||||
n_samples: int = field(
|
||||
default=1, metadata={"help": "Number of sequences to generate per prompt."}
|
||||
)
|
||||
max_new_tokens: int = field(
|
||||
default=16384, metadata={"help": "Maximum number of tokens to generate."}
|
||||
)
|
||||
min_new_tokens: int = field(
|
||||
default=0, metadata={"help": "Minimum number of tokens to generate."}
|
||||
)
|
||||
greedy: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to use greedy decoding (max probability)."},
|
||||
)
|
||||
top_p: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "Nucleus sampling probability threshold (0.0, 1.0]."},
|
||||
)
|
||||
top_k: int = field(
|
||||
default=int(1e8),
|
||||
metadata={"help": "Number of highest probability tokens to consider."},
|
||||
)
|
||||
temperature: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "Sampling temperature. Higher values increase diversity."},
|
||||
)
|
||||
stop_token_ids: List[int] = field(
|
||||
default_factory=list,
|
||||
metadata={"help": "Stop generation when encoutering these token ids."},
|
||||
)
|
||||
|
||||
def new(self, **kwargs):
|
||||
args = asdict(self)
|
||||
args.update(kwargs)
|
||||
return GenerationHyperparameters(**args)
|
|
@ -0,0 +1,124 @@
|
|||
import abc
|
||||
from typing import Callable, Dict, List, Any, Optional
|
||||
import torch
|
||||
from dataclasses import dataclass, field
|
||||
from concurrent.futures import Future
|
||||
from arealite.api.cli_args import MicroBatchSpec
|
||||
from arealite.api.io_struct import (
|
||||
LLMRequest,
|
||||
LLMResponse,
|
||||
FinetuneSpec,
|
||||
WeightUpdateMeta,
|
||||
SaveLoadMeta,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Scheduling:
|
||||
cpu: int
|
||||
gpu: int
|
||||
mem: int
|
||||
nodelist: str = None
|
||||
exclude: str = None
|
||||
partition: str = None
|
||||
container_image: str = None
|
||||
env_vars: Dict[str, str] = field(default_factory=dict)
|
||||
# time utils from "https://slurm.schedmd.com/sbatch.html"
|
||||
time_limit: Optional[str] = None # see "--time" option for format
|
||||
begin: Optional[str] = None # see "--begin" option for format
|
||||
deadline: Optional[str] = None # see "--deadline" option for format
|
||||
|
||||
|
||||
class TrainEngine(abc.ABC):
|
||||
|
||||
def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None):
|
||||
"""Initialize environments for distributed training and load models."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_scheduling_config(self) -> Scheduling:
|
||||
"""Get the scheduling configuration for the engine, e.g., image, cpu/gpu/memory size."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def destroy(self):
|
||||
"""Destroy the engine and release GPU memory."""
|
||||
pass
|
||||
|
||||
def upload_weights(self, meta: WeightUpdateMeta):
|
||||
"""Upload weights to the inference engine."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def save(self, meta: SaveLoadMeta):
|
||||
"""Save model weights (and optimizer states) for later use."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def load(self, meta: SaveLoadMeta):
|
||||
"""Load model weights and optimizer states from a file."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def step_lr_scheduler(self):
|
||||
"""Step learning rate scheduler.
|
||||
|
||||
Since PPO uses minibatch updates, this method just need to be called once after a few train_batch calls.
|
||||
It is separated from train_batch to allow for more flexible scheduling.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def train_batch(
|
||||
self,
|
||||
input_: Dict,
|
||||
mb_spec: MicroBatchSpec,
|
||||
loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
|
||||
loss_weight_fn: Callable[[Dict], float],
|
||||
) -> Dict[str, float]:
|
||||
"""Update the model with a batch of data and a loss function."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@torch.no_grad()
|
||||
def eval_batch(
|
||||
self,
|
||||
input_: Dict,
|
||||
mb_spec: MicroBatchSpec,
|
||||
loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
|
||||
loss_weight_fn: Callable[[Dict], float],
|
||||
) -> torch.Tensor | None:
|
||||
"""Evaluate the model using the forward pass and loss function."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
input_: Dict,
|
||||
mb_spec: MicroBatchSpec,
|
||||
output_seqlens: List[List[int]] | None = None,
|
||||
post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
|
||||
aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
|
||||
) -> Any | None:
|
||||
"""Run the forward pass or inference on the model. Note that it is gradient-free."""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class InferenceEngine(abc.ABC):
|
||||
|
||||
def initialize(self, addr: str | None, ft_spec):
|
||||
"""Initialize environments for distributed inference and load models."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def update_weights(self, meta: WeightUpdateMeta) -> Future:
|
||||
"""Update weights in the inference engine."""
|
||||
raise NotImplementedError()
|
||||
|
||||
async def agenerate(self, req: LLMRequest) -> LLMResponse:
|
||||
"""Asynchronously generate a response for the given request."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def submit(self, data: Dict[str, Any], workflow) -> None:
|
||||
"""Asynchronously submit a request to the inference engine. Exits immediately."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def wait(self, count: int, timeout: int) -> Any:
|
||||
"""Wait for a specified number of requests to complete, with a timeout."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def rollout(self, data: List[Dict[str, Any]], workflow) -> Any:
|
||||
"""Submit a batch of requests to the inference engine and wait for the results."""
|
||||
raise NotImplementedError()
|
|
@ -0,0 +1,30 @@
|
|||
import abc
|
||||
from typing import Any, Dict, List, Callable
|
||||
|
||||
|
||||
class Environment(abc.ABC):
|
||||
|
||||
async def ainitialize(self):
|
||||
"""
|
||||
Performs the initialization logic for the environment asynchronously.
|
||||
|
||||
For stateful environments, this is where resources are created and
|
||||
prepared (e.g., launching a browser).
|
||||
"""
|
||||
pass
|
||||
|
||||
def list_tools(self) -> List[Dict[str, Any]]:
|
||||
"""Lists all available tools in the environment."""
|
||||
return []
|
||||
|
||||
async def aexecute(self, tool_name: str, tool_args: Dict[str, Any]) -> Any:
|
||||
"""Executes a tool in the environment asynchronously."""
|
||||
raise NotImplementedError()
|
||||
|
||||
async def aclose(self):
|
||||
"""
|
||||
Destroys the environment asynchronously, releasing all held resources.
|
||||
|
||||
This method is critical for stateful environments (e.g., a browser session).
|
||||
"""
|
||||
pass
|
|
@ -0,0 +1,172 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
import enum
|
||||
import itertools
|
||||
import re
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
|
||||
from arealite.api.cli_args import GenerationHyperparameters
|
||||
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMRequest:
|
||||
rid: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
text: Optional[str] = None
|
||||
input_ids: List[int] = field(default_factory=list)
|
||||
gconfig: GenerationHyperparameters = field(
|
||||
default_factory=GenerationHyperparameters
|
||||
)
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
model_id: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMResponse:
|
||||
# outputs
|
||||
completions: str
|
||||
input_tokens: List[int] = field(default_factory=list)
|
||||
output_tokens: List[int] = field(default_factory=list)
|
||||
output_logprobs: List[float] = field(default_factory=list)
|
||||
output_versions: List[int] = field(default_factory=list)
|
||||
stop_reason: Literal["length", "stop", "interrupt"] = "stop"
|
||||
|
||||
# statistics
|
||||
latency: float = float("inf")
|
||||
ttft: float = float("inf") # Time to first token
|
||||
itl: List[float] = field(default_factory=list) # List of inter-token latencies
|
||||
|
||||
@property
|
||||
def input_len(self) -> int:
|
||||
return len(self.input_tokens)
|
||||
|
||||
@property
|
||||
def output_len(self) -> int:
|
||||
return len(self.output_tokens)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FinetuneSpec:
|
||||
total_train_epochs: int
|
||||
dataset_size: int
|
||||
train_batch_size: int
|
||||
|
||||
@property
|
||||
def total_train_steps(self):
|
||||
# assuming drop_last
|
||||
return self.total_train_epochs * (self.dataset_size // self.train_batch_size)
|
||||
|
||||
|
||||
class AllocationType(enum.Enum):
|
||||
DECOUPLED_vLLM = 1
|
||||
DECOUPLED_SGLANG = 2
|
||||
|
||||
|
||||
@dataclass
|
||||
class AllocationMode:
|
||||
type_: AllocationType
|
||||
parallel_strat: None | Dict[str, Dict[str, int]]
|
||||
|
||||
@property
|
||||
def gen_tp_size(self) -> int:
|
||||
return self.parallel_strat["gen"]["t"]
|
||||
|
||||
@property
|
||||
def gen_pp_size(self) -> int:
|
||||
return self.parallel_strat["gen"]["p"]
|
||||
|
||||
@property
|
||||
def gen_dp_size(self) -> int:
|
||||
return self.parallel_strat["gen"]["d"]
|
||||
|
||||
@property
|
||||
def gen_world_size(self) -> int:
|
||||
return self.gen_dp_size * self.gen_pp_size * self.gen_tp_size
|
||||
|
||||
@property
|
||||
def train_tp_size(self) -> int:
|
||||
return self.parallel_strat["*"]["t"]
|
||||
|
||||
@property
|
||||
def train_pp_size(self) -> int:
|
||||
return self.parallel_strat["*"]["p"]
|
||||
|
||||
@property
|
||||
def train_dp_size(self) -> int:
|
||||
return self.parallel_strat["*"]["d"]
|
||||
|
||||
@property
|
||||
def train_world_size(self) -> int:
|
||||
return self.train_dp_size * self.train_pp_size * self.train_tp_size
|
||||
|
||||
@classmethod
|
||||
def from_str(cls, allocation_mode: str):
|
||||
alloc_decoupled = AllocationMode.extract_decoupled_alloc(allocation_mode)
|
||||
if "vllm" in allocation_mode:
|
||||
return cls(AllocationType.DECOUPLED_vLLM, alloc_decoupled)
|
||||
elif "sglang" in allocation_mode:
|
||||
return cls(AllocationType.DECOUPLED_SGLANG, alloc_decoupled)
|
||||
raise NotImplementedError(f"Failed to parse allocation: {allocation_mode}")
|
||||
|
||||
@staticmethod
|
||||
def extract_3d_alloc(allocation_mode: str) -> Dict | None:
|
||||
for x, y, z in itertools.permutations(["d", "t", "p"]):
|
||||
pattern = rf"{x}(\d+){y}(\d+){z}(\d+)"
|
||||
m = re.match(pattern, allocation_mode)
|
||||
if not m:
|
||||
continue
|
||||
a, b, c = map(int, m.groups())
|
||||
# to be consistent with the key-value pattern
|
||||
return {
|
||||
"*": {
|
||||
x: a,
|
||||
y: b,
|
||||
z: c,
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def extract_decoupled_alloc(allocation_mode: str) -> Dict | None:
|
||||
pattern = re.compile(
|
||||
r"(?:(?:vllm|sglang)\.(.+?)\+(.+))|(?:(.+?)\+(?:vllm|sglang)\.(.+))"
|
||||
)
|
||||
m = pattern.match(allocation_mode)
|
||||
if not m:
|
||||
return
|
||||
if m.group(1):
|
||||
gen_alloc = m.group(1)
|
||||
other_alloc = m.group(2)
|
||||
else:
|
||||
gen_alloc = m.group(4)
|
||||
other_alloc = m.group(3)
|
||||
gen_alloc = AllocationMode.extract_3d_alloc(gen_alloc)
|
||||
if not gen_alloc:
|
||||
return
|
||||
other_alloc = AllocationMode.extract_3d_alloc(
|
||||
other_alloc
|
||||
) or AllocationMode.extract_key_value_alloc(other_alloc)
|
||||
if not other_alloc:
|
||||
return
|
||||
other_alloc.update({"gen": gen_alloc["*"]})
|
||||
return other_alloc
|
||||
|
||||
@dataclass
|
||||
class WeightUpdateMeta:
|
||||
type: str
|
||||
path: str | None
|
||||
alloc_mode: AllocationMode | None
|
||||
comm_backend: str | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SaveLoadMeta:
|
||||
path: str
|
||||
weight_format: str
|
||||
with_optim: bool
|
||||
tokenizer: PreTrainedTokenizerFast | None
|
||||
base_model_path: str | None
|
|
@ -0,0 +1,24 @@
|
|||
from typing import List
|
||||
|
||||
|
||||
def reward_fn(
|
||||
prompt: str,
|
||||
completions: str,
|
||||
prompt_ids: List[int],
|
||||
completion_ids: List[int],
|
||||
**kwargs,
|
||||
):
|
||||
"""This function is a placeholder for the reward function that will be used in the RLVR pipeline.
|
||||
|
||||
In general, there's no restriction on the signature and implementation of this function in customized rollout workflows.
|
||||
It would be convinent to follow this signature and directly use it in our predefined rollout workflows.
|
||||
|
||||
:param prompt: The string representing the task to be completed.
|
||||
:param completions: The string representing the trajectory generated by the model.
|
||||
:param prompt_ids: The token IDs of the prompt.
|
||||
:param completion_ids: The token IDs of the trajectory generated by the model.
|
||||
:param kwargs: Other attributes of the data in the dataset, such as solutions, input_outputs, etc.
|
||||
Any other attributes in the dataset will be passed as keyword arguments to this function.
|
||||
:rtype: float
|
||||
"""
|
||||
pass
|
|
@ -0,0 +1,17 @@
|
|||
from typing import TYPE_CHECKING, Dict, Any
|
||||
from tensordict import TensorDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from arealite.api.engine_api import InferenceEngine
|
||||
|
||||
|
||||
class RolloutWorkflow:
|
||||
|
||||
async def arun_episode(
|
||||
self, engine: InferenceEngine, data: Dict[str, Any]
|
||||
) -> TensorDict:
|
||||
"""Run a single episode of the workflow.
|
||||
|
||||
See concrete example implementations under the `arealite/workflow` directory.
|
||||
"""
|
||||
raise NotImplementedError()
|
|
@ -0,0 +1,54 @@
|
|||
from tensordict import TensorDict
|
||||
from arealite.api.cli_args import GenerationHyperparameters
|
||||
from arealite.api.workflow_api import RolloutWorkflow
|
||||
from arealite.api.io_struct import LLMRequest
|
||||
import uuid
|
||||
import torch
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
|
||||
|
||||
class RLVRWorkflow(RolloutWorkflow):
|
||||
def __init__(
|
||||
self,
|
||||
reward_fn,
|
||||
gconfig: GenerationHyperparameters,
|
||||
tokenizer: PreTrainedTokenizerFast,
|
||||
):
|
||||
self.reward_fn = reward_fn
|
||||
self.gconfig = gconfig
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
async def arun_episode(self, engine, data):
|
||||
text = self.tokenizer.apply_chat_template(
|
||||
data["messages"], tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
req = LLMRequest(
|
||||
rid=uuid.uuid4().hex,
|
||||
text=text,
|
||||
gconfig=self.gconfig,
|
||||
)
|
||||
resp = await engine.agenerate(req)
|
||||
|
||||
seq = resp.input_tokens + resp.output_tokens
|
||||
logprobs = [0] * resp.input_len + resp.output_logprobs
|
||||
prompt_mask = [1] * resp.input_len + [0] * resp.output_len
|
||||
versions = [-1] * resp.input_len + resp.output_versions
|
||||
|
||||
reward = self.reward_fn(
|
||||
prompt=req.text,
|
||||
completions=resp.completions,
|
||||
prompt_ids=resp.input_tokens,
|
||||
completion_ids=resp.output_tokens,
|
||||
**data,
|
||||
)
|
||||
res = dict(
|
||||
# unsqueeze to add an additional batch dimension
|
||||
input_ids=torch.tensor(seq).unsqueeze(0),
|
||||
prompt_mask=torch.tensor(prompt_mask).unsqueeze(0),
|
||||
logprobs=torch.tensor(logprobs).unsqueeze(0),
|
||||
versions=torch.tensor(versions).unsqueeze(0),
|
||||
# reward
|
||||
rewards=torch.tensor([reward]),
|
||||
)
|
||||
|
||||
return TensorDict(res, batch_size=[1])
|
Loading…
Reference in New Issue