PullRequest: 331 [lite] Support remote SGLang engine with corresponding test cases.

Merge branch fw/lite of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/331

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* add test for sglang remote engine
* fix
This commit is contained in:
博惟 2025-07-09 14:18:55 +08:00
parent 57b9b945ab
commit 8771778995
10 changed files with 510 additions and 60 deletions

View File

@ -83,7 +83,6 @@ def main_grpo():
# or asynchronous rollout with filtering and off-policyness control
# rollout_batch = rollout.prepare_batch(batch,
# workflow=MyRolloutWorkflow(rollout_config.workflow),
# offpolicyness=4,
# should_accept=lambda x: x['rewards'].mean() > 0)
# In the single-controller mode

View File

@ -1,5 +1,5 @@
from dataclasses import asdict, dataclass, field
from typing import List
from typing import Dict, List, Optional
@dataclass
@ -70,8 +70,157 @@ class GenerationHyperparameters:
return GenerationHyperparameters(**args)
@dataclass
class SGLangConfig:
"""Configuration for SGLang runtime. Refer to:
https://github.com/sgl-project/sglang for detailed documentation.
"""
disable_cuda_graph: bool = False
disable_radix_cache: bool = False
disable_cuda_graph_padding: bool = False
enable_nccl_nvls: bool = False
disable_outlines_disk_cache: bool = False
disable_custom_all_reduce: bool = False
disable_overlap_schedule: bool = False
enable_mixed_chunk: bool = False
enable_dp_attention: bool = False
enable_ep_moe: bool = False
enable_torch_compile: bool = False
torch_compile_max_bs: int = 32
cuda_graph_max_bs: Optional[int] = None
cuda_graph_bs: Optional[List[int]] = None
torchao_config: str = ""
enable_nan_detection: bool = False
enable_p2p_check: bool = False
triton_attention_reduce_in_fp32: bool = False
triton_attention_num_kv_splits: int = 8
num_continuous_decode_steps: int = 1
enable_memory_saver: bool = False
allow_auto_truncate: bool = False
# NOTE: defaults to flashinfer to avoid illegal memory access errors
attention_backend: Optional[str] = "flashinfer"
sampling_backend: Optional[str] = None
context_length: Optional[int] = 32768
mem_fraction_static: Optional[float] = 0.9
max_running_requests: Optional[int] = None
# NOTE: chunked_prefill_size is by default 8192 on GPUs with 80GB mem in SGLang,
# but we disable it to avoid precision issues
chunked_prefill_size: Optional[int] = -1
max_prefill_tokens: int = 32768
schedule_policy: str = "lpm"
schedule_conservativeness: float = 1.0
cpu_offload_gb: int = 0
dtype: str = "float16"
kv_cache_dtype: str = "auto"
# logging
log_level: str = "warning"
log_level_http: Optional[str] = "warning"
log_requests: bool = False
log_requests_level: int = 0
show_time_cost: bool = False
enable_metrics: bool = True # Exports Prometheus-like metrics
# The interval (in decoding iterations) to log throughput
# and update Prometheus metrics
decode_log_interval: int = 1
# Use staticmethod to make OmegaConf happy.
@staticmethod
def build_cmd(
sglang_config: "SGLangConfig",
model_path,
tp_size,
base_gpu_id,
dist_init_addr: Optional[str] = None,
served_model_name: Optional[str] = None,
skip_tokenizer_init: bool = True,
):
from realhf.base import network, pkg_version, seeding
from realhf.experiments.common.utils import asdict as conf_as_dict
args: Dict = conf_as_dict(sglang_config)
args["random_seed"] = seeding.get_seed()
if served_model_name is None:
served_model_name = model_path
host_ip = network.gethostip()
host = "localhost" if not sglang_config.enable_metrics else host_ip
args = dict(
host=host,
model_path=model_path,
# Model and tokenizer
tokenizer_path=model_path,
tokenizer_mode="auto",
load_format="auto",
trust_remote_code=True,
device="cuda",
served_model_name=served_model_name,
is_embedding=False,
skip_tokenizer_init=skip_tokenizer_init,
# Other runtime options
tp_size=tp_size,
# Because we have set CUDA_VISIBLE_DEVICES to a single GPU in each process
base_gpu_id=base_gpu_id,
nnodes=1,
node_rank=0,
dist_init_addr=dist_init_addr,
**args,
)
if pkg_version.is_version_less("sglang", "0.4.4"):
args.pop("log_requests_level")
if pkg_version.is_version_less("sglang", "0.4.3"):
args.pop("enable_nccl_nvls")
args.pop("triton_attention_num_kv_splits")
args.pop("cuda_graph_bs")
args.pop("enable_memory_saver")
args.pop("allow_auto_truncate")
args.pop("file_storage_path")
flags = []
for k, v in args.items():
if v is None or v is False or v == "":
continue
if v is True:
flags.append(f"--{k.replace('_','-')} ")
continue
if isinstance(v, list):
values = " ".join(map(str, v))
flags.append(f"--{k.replace('_','-')} {values}")
continue
flags.append(f"--{k.replace('_','-')} {v}")
flags = " ".join(flags)
return f"python3 -m sglang.launch_server {flags}"
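The flag-building loop above drops options whose value is None, False, or an empty string, emits True booleans as bare flags, and joins list values with spaces. A minimal standalone sketch of that convention, using field names from the config above (the helper itself is illustrative, not part of the PR):

def to_flags(args: dict) -> str:
    flags = []
    for k, v in args.items():
        if v is None or v is False or v == "":
            continue  # unset options are omitted entirely
        if v is True:
            flags.append(f"--{k.replace('_', '-')}")  # boolean switch, no value
        elif isinstance(v, list):
            flags.append(f"--{k.replace('_', '-')} {' '.join(map(str, v))}")
        else:
            flags.append(f"--{k.replace('_', '-')} {v}")
    return " ".join(flags)

# Illustrative inputs only:
print(to_flags({"tp_size": 2, "enable_metrics": True,
                "cuda_graph_bs": [1, 2, 4], "dist_init_addr": None}))
# -> "--tp-size 2 --enable-metrics --cuda-graph-bs 1 2 4"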
@dataclass
class InferenceEngineConfig:
experiment_name: str
trial_name: str
max_concurrent_rollouts: None | int = field(
default=None,
metadata={
"help": "Maximum number of concurrent rollouts to the inference engine. Defaults to consumer_batch_size."
},
)
queue_size: None | int = field(
default=None,
metadata={"help": "Input/Output queue size for async rollout."},
)
consumer_batch_size: int = field(
default=1,
metadata={"help": "Batch size for consuming rollouts from the queue."},
)
max_head_offpolicyness: int = field(
default=0,
metadata={
"help": "Maximum off-policyness for the head of the generation queue. "
"If a new rollout would be more than this many versions ahead of "
"the current weight version, its submission is held back.",
},
)
# Used by remote inference engines.
server_addrs: List[str] = field(
default_factory=list,

View File

@ -109,6 +109,10 @@ class InferenceEngine(abc.ABC):
"""Initialize environments for distributed inference and load models."""
raise NotImplementedError()
def destroy(self):
"""Destroy the engine and release GPU memory."""
pass
def update_weights(self, meta: WeightUpdateMeta) -> Future:
"""Update weights in the inference engine."""
raise NotImplementedError()
@ -121,7 +125,7 @@ class InferenceEngine(abc.ABC):
"""Asynchronously submit a request to the inference engine. Exits immediately."""
raise NotImplementedError()
def wait(self, count: int, timeout: int) -> TensorDict:
def wait(self, count: int, timeout: float) -> TensorDict:
"""Wait for a specified number of requests to complete, with a timeout."""
raise NotImplementedError()

View File

@ -9,7 +9,7 @@ if TYPE_CHECKING:
class RolloutWorkflow:
async def arun_episode(
self, engine: InferenceEngine, data: Dict[str, Any]
self, engine: "InferenceEngine", data: Dict[str, Any]
) -> TensorDict:
"""Run a single episode of the workflow.

View File

@ -3,8 +3,8 @@ import threading
import time
import traceback
from concurrent.futures import ThreadPoolExecutor
from queue import Empty, Queue
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
from queue import Empty, Full, Queue
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
import aiohttp
import torch.distributed as dist
@ -18,7 +18,7 @@ from arealite.api.io_struct import (
RolloutStat,
WeightUpdateMeta,
)
from realhf.base import logging, name_resolve, names, pkg_version
from realhf.base import logging, pkg_version
if TYPE_CHECKING:
from arealite.api.workflow_api import RolloutWorkflow
@ -31,19 +31,27 @@ if pkg_version.is_available("sglang"):
SGLANG_TOKEN_OUTPUT_IDENTIFIER = "token_ids"
ROLLOUT_POLL_WAIT_TIME = 0.4
RID_CACHE_SIZE = 128
class RemoteSGLangEngine(InferenceEngine):
def __init__(self, config: InferenceEngineConfig):
config.max_concurrent_rollouts = (
config.max_concurrent_rollouts or config.consumer_batch_size
)
self.config = config
self.rid_to_address = {}
# Maintain the addresses for the recent 128 requests
self.rid_queue = []
self.addresses = config.server_addrs
self.server_idx = 0
self.input_queue = Queue(maxsize=config.max_concurrent_rollouts)
self.output_queue = Queue(maxsize=config.max_concurrent_rollouts)
qsize = config.queue_size or config.max_concurrent_rollouts * 10
self.input_queue = Queue(maxsize=qsize)
self.output_queue = Queue(maxsize=qsize)
self.result_cache = []
self.exiting = threading.Event()
@ -51,32 +59,35 @@ class RemoteSGLangEngine(InferenceEngine):
self.rollout_stat = RolloutStat()
def _get_model_version(self) -> int:
name = names.model_version(
self.config.experiment_name,
self.config.trial_name,
"actor",
)
try:
return int(name_resolve.get(name))
except name_resolve.NameEntryNotFoundError:
return 0
self._version = 0
def initialize(self, addr: str | None, ft_spec: Optional[Dict[str, Any]] = None):
self.rollout_thread = threading.Thread(target=self._rollout_thread)
self.rollout_thread.start()
def destroy(self):
self.exiting.set()
self.rollout_thread.join()
def set_version(self, version):
with self.lock:
self._version = version
def get_version(self):
with self.lock:
return self._version
def _rollout_thread(self):
"""Thread that runs the rollout loop."""
try:
asyncio.run_coroutine_threadsafe(self._rollout_thread_async())
finally:
self.exiting.set()
asyncio.run(self._rollout_thread_async())
except Exception as e:
traceback.print_exc()
async def _rollout_thread_async(self):
data = None
rollout_tasks: Dict[int, asyncio.Task] = {}
rollout_tasks: Dict[str, asyncio.Task] = {}
rid = 0
try:
@ -85,7 +96,7 @@ class RemoteSGLangEngine(InferenceEngine):
if data is None:
try:
data, workflow = self.input_queue.get_nowait()
logger.debug(f"Get data from puller: {data}")
logger.info(f"Get data from puller: {data}")
except Empty:
logger.debug(f"No data from puller stream.")
@ -104,17 +115,17 @@ class RemoteSGLangEngine(InferenceEngine):
)
# Staleness control
version = self._get_model_version()
version = self.get_version()
ofp = self.config.max_head_offpolicyness
with self.lock:
sample_cnt = self.rollout_stat.accepted + self.rollout_stat.running
expected_version = sample_cnt // self.train_batch_size
expected_version = sample_cnt // self.config.consumer_batch_size
not_staled = expected_version <= ofp + version
can_rollout &= not_staled
if not not_staled:
cannot_rollout_reason.append(
f"Staled: expected version ({expected_version}) = "
f"global sample cnt ({sample_cnt}) // batch size ({self.train_batch_size}), "
f"global sample cnt ({sample_cnt}) // batch size ({self.config.consumer_batch_size}), "
f"current latest version {version}, "
f"offpolicyness {self.config.max_head_offpolicyness}."
)
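To make the staleness condition in this hunk concrete: a new rollout is only submitted while its expected version (accepted plus running samples, integer-divided by consumer_batch_size) is at most max_head_offpolicyness ahead of the current weight version. A minimal illustrative sketch (the helper below is not part of the PR):

def can_submit(accepted: int, running: int, consumer_batch_size: int,
               current_version: int, max_head_offpolicyness: int) -> bool:
    # The version this sample would belong to if rollouts are consumed
    # in training batches of `consumer_batch_size`.
    expected_version = (accepted + running) // consumer_batch_size
    return expected_version <= max_head_offpolicyness + current_version

# With consumer_batch_size=2 and max_head_offpolicyness=1 at weight version 0,
# up to 4 samples may be accepted or running; the 5th is held back until
# set_version() advances the weight version.
assert can_submit(0, 3, consumer_batch_size=2, current_version=0, max_head_offpolicyness=1)      # 3 // 2 = 1 <= 1
assert not can_submit(2, 2, consumer_batch_size=2, current_version=0, max_head_offpolicyness=1)  # 4 // 2 = 2 > 1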
@ -130,12 +141,12 @@ class RemoteSGLangEngine(InferenceEngine):
task = asyncio.create_task(
workflow.arun_episode(self, data), name=str(rid)
)
rollout_tasks[rid] = task
rollout_tasks[str(rid)] = task
with self.lock:
self.rollout_stat.submitted += 1
self.rollout_stat.running += 1
logger.debug(
logger.info(
f"Submit rollout rid {rid}. "
f"Submit: {self.rollout_stat.submitted}, "
f"running: {self.rollout_stat.running}, "
@ -163,12 +174,18 @@ class RemoteSGLangEngine(InferenceEngine):
traj: TensorDict
task_rid = task.get_name()
rollout_tasks.pop(task_rid)
self.rollout_stat.accepted += 1
self.output_queue.put(traj)
try:
self.output_queue.put_nowait(traj)
except Full:
raise RuntimeError(
"Output queue full. Please increase queue_size."
)
with self.lock:
self.rollout_stat.running -= 1
logger.debug(
logger.info(
f"Finish rollout {task_rid}. "
f"Submit: {self.rollout_stat.submitted}, "
f"running: {self.rollout_stat.running}, "
@ -256,7 +273,11 @@ class RemoteSGLangEngine(InferenceEngine):
gconfig = req.gconfig
stop_token_ids = gconfig.stop_token_ids
assert gconfig.n_samples == 1
if gconfig.n_samples != 1:
raise ValueError(
"RemoteSGLangEngine does not support n_samples > 1. "
"Please call generate multiple times with n_samples = 1."
)
sample_params = {
"top_p": gconfig.top_p,
"top_k": gconfig.top_k,
@ -265,8 +286,8 @@ class RemoteSGLangEngine(InferenceEngine):
"stop_token_ids": stop_token_ids,
}
# NOTE: rid should NOT be passed in payload
payload = {
"rid": req.rid,
"text": req.text,
"sampling_params": sample_params,
"return_logprob": True,
@ -287,6 +308,17 @@ class RemoteSGLangEngine(InferenceEngine):
completions = ""
stop_reason = "length"
if req.rid in self.rid_to_address:
server_addr = self.rid_to_address[req.rid]
else:
server_addr = self.choose_server()
if len(self.rid_queue) >= RID_CACHE_SIZE:
# Remove the oldest entry if cache is full
oldest_rid = self.rid_queue.pop(0)
self.rid_to_address.pop(oldest_rid, None)
self.rid_to_address[req.rid] = server_addr
self.rid_queue.append(req.rid)
while (
stop_reason != "stop"
and len(accumulated_output_tokens) < gconfig.max_new_tokens
@ -298,6 +330,7 @@ class RemoteSGLangEngine(InferenceEngine):
method="POST",
max_retries=3,
timeout=self.config.request_timeout,
target_addr=server_addr,
)
result = await response.json()
@ -369,9 +402,12 @@ class RemoteSGLangEngine(InferenceEngine):
)
def submit(self, data: Dict[str, Any], workflow: "RolloutWorkflow") -> None:
self.input_queue.put((workflow, data))
try:
self.input_queue.put_nowait((data, workflow))
except Full:
raise RuntimeError("Input queue full. Please increase queue_size.")
def wait(self, count: int, timeout: int, should_accept: Callable) -> TensorDict:
def wait(self, count: int, timeout: float, should_accept: Callable) -> TensorDict:
tik = time.perf_counter()
accepted = len(self.result_cache)
while (
@ -384,8 +420,9 @@ class RemoteSGLangEngine(InferenceEngine):
if should_accept(result):
self.result_cache.append(result)
accepted += 1
else:
with self.lock:
self.rollout_stat.accepted += 1
self.rollout_stat.accepted -= 1
except Empty:
time.sleep(ROLLOUT_POLL_WAIT_TIME)
if self.exiting.is_set():
@ -399,3 +436,15 @@ class RemoteSGLangEngine(InferenceEngine):
self.result_cache[count:],
)
return TensorDict.cat(results, dim=0)
def rollout(
self, data: List[Dict[str, Any]], workflow: "RolloutWorkflow"
) -> TensorDict:
"""Submit a batch of requests to the inference engine and wait for the results."""
for item in data:
self.submit(item, workflow)
return self.wait(
count=len(data),
timeout=self.config.request_timeout,
should_accept=lambda x: True,
)
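Taken together, the interface added in this file can be driven roughly as follows. This is a hedged sketch mirroring the tests later in this PR; the server address is a placeholder for a running sglang.launch_server endpoint, and `data` and `my_workflow` stand for a dataset item and a RolloutWorkflow instance as used in the tests.

from arealite.api.cli_args import InferenceEngineConfig
from arealite.engine.sglang_remote import RemoteSGLangEngine

config = InferenceEngineConfig(
    experiment_name="my_experiment",  # placeholder names
    trial_name="trial_0",
    consumer_batch_size=2,
    max_head_offpolicyness=1,
)
config.server_addrs = ["10.0.0.1:13887"]  # assumed address of a running SGLang server

engine = RemoteSGLangEngine(config)
engine.initialize(None, None)  # starts the background rollout thread
# Batched path: submit() each item, then wait() for the results.
batch = engine.rollout([data, data], workflow=my_workflow)
engine.set_version(1)          # advance the weight version used for staleness control
engine.destroy()               # signal the rollout thread to exit and join it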

View File

@ -0,0 +1,188 @@
import os
import subprocess
import sys
import time
import uuid
import pytest
import requests
import torch
from tensordict import TensorDict
from arealite.api.cli_args import (
GenerationHyperparameters,
InferenceEngineConfig,
SGLangConfig,
)
from arealite.api.io_struct import FinetuneSpec, LLMRequest, LLMResponse
from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import name_resolve, network, seeding
EXPR_NAME = "test_sglang_engine"
TRIAL_NAME = "trial_0"
MODEL_PATH = "/storage/testing/models/Qwen__Qwen3-1.7B/"
if not os.path.exists(MODEL_PATH):
MODEL_PATH = "Qwen/Qwen2-0.5B"
PORT = 13887
DIST_PORT = 15887
HOST = network.gethostip()
def check_server_health(base_url):
# Check server endpoint
try:
response = requests.get(
f"{base_url}/metrics",
timeout=30,
)
return response.status_code == 200
except requests.exceptions.RequestException:
return False
@pytest.fixture(scope="module")
def sglang_server():
from realhf.base import seeding
seeding.set_random_seed(1, EXPR_NAME)
cmd = SGLangConfig.build_cmd(
sglang_config=SGLangConfig(mem_fraction_static=0.3),
model_path=MODEL_PATH,
tp_size=1,
base_gpu_id=0,
dist_init_addr=f"{HOST}:{DIST_PORT}",
served_model_name=MODEL_PATH,
skip_tokenizer_init=False,
)
# Launch process
full_command = f"{cmd} --port {PORT}"
full_command = full_command.replace("\\\n", " ").replace("\\", " ")
process = subprocess.Popen(
full_command.split(),
text=True,
stdout=sys.stdout,
stderr=sys.stdout,
)
base_url = f"http://{HOST}:{PORT}"
tik = time.time()
while time.time() - tik < 90:
if check_server_health(base_url):
break
time.sleep(1)
if time.time() - tik > 90:
raise RuntimeError("server launch failed")
yield
process.terminate()
@pytest.mark.skip("")
@pytest.mark.asyncio
async def test_remote_sglang_generate(sglang_server):
from arealite.engine.sglang_remote import RemoteSGLangEngine
config = InferenceEngineConfig(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME)
config.server_addrs = [f"{HOST}:{PORT}"]
engine = RemoteSGLangEngine(config)
req = LLMRequest(
rid=str(uuid.uuid4()),
text="hello! how are you today",
gconfig=GenerationHyperparameters(max_new_tokens=16),
)
resp = await engine.agenerate(req)
assert isinstance(resp, LLMResponse)
assert resp.input_tokens == req.input_ids
assert (
len(resp.output_logprobs)
== len(resp.output_tokens)
== len(resp.output_versions)
)
assert isinstance(resp.completions, str)
@pytest.mark.skip("")
@pytest.mark.parametrize("n_samples", [1, 2, 4])
def test_remote_sglang_rollout(sglang_server, n_samples):
from arealite.engine.sglang_remote import RemoteSGLangEngine
from arealite.workflow.rlvr import RLVRWorkflow
config = InferenceEngineConfig(
experiment_name=EXPR_NAME,
trial_name=TRIAL_NAME,
max_concurrent_rollouts=2,
consumer_batch_size=2,
)
config.server_addrs = [f"{HOST}:{PORT}"]
engine = RemoteSGLangEngine(config)
engine.initialize(None, None)
gconfig = GenerationHyperparameters(
max_new_tokens=16, greedy=False, n_samples=n_samples
)
tokenizer = load_hf_tokenizer(MODEL_PATH)
workflow = RLVRWorkflow(
reward_fn=lambda **kwargs: 1.0, # Dummy reward function
gconfig=gconfig,
tokenizer=tokenizer,
)
data = {
"messages": [{"role": "user", "content": "Hello, how are you?"}],
}
result = engine.rollout([data] * 2, workflow=workflow)
assert isinstance(result, TensorDict)
bs = result.batch_size
assert bs == torch.Size([2 * n_samples])
engine.destroy()
@pytest.mark.parametrize("ofp", [1, 2, 4, 8, 16])
@pytest.mark.parametrize("bs", [2, 4])
@pytest.mark.parametrize("n_samples", [2, 1])
def test_remote_sglang_staleness_control(sglang_server, bs, ofp, n_samples):
from arealite.engine.sglang_remote import RemoteSGLangEngine
from arealite.workflow.rlvr import RLVRWorkflow
config = InferenceEngineConfig(
experiment_name=EXPR_NAME,
trial_name=TRIAL_NAME,
consumer_batch_size=bs,
max_head_offpolicyness=ofp,
)
config.server_addrs = [f"{HOST}:{PORT}"]
engine = RemoteSGLangEngine(config)
engine.initialize(None, None)
gconfig = GenerationHyperparameters(
max_new_tokens=16, greedy=False, n_samples=n_samples
)
tokenizer = load_hf_tokenizer(MODEL_PATH)
workflow = RLVRWorkflow(
reward_fn=lambda **kwargs: 1.0, # Dummy reward function
gconfig=gconfig,
tokenizer=tokenizer,
)
data = {
"messages": [{"role": "user", "content": "Hello, how are you?"}],
}
for _ in range(bs * 2):
engine.submit(data, workflow=workflow)
# wait for some time
time.sleep(15)
assert engine.output_queue.qsize() == min(bs * 2, bs * (ofp + 1))
# Update model version
engine.set_version(1)
print("Updated model version", flush=True)
# submit again
for _ in range(bs * 2):
engine.submit(data, workflow=workflow)
# wait for some time
time.sleep(15)
assert engine.output_queue.qsize() == min(bs * 4, bs * (ofp + 2))
# exit
engine.destroy()

arealite/utils/padding.py (new file, +48 lines)
View File

@ -0,0 +1,48 @@
from typing import Any, Dict, List
import torch
from tensordict import TensorDict
def concat_padded_tensors(
tensor_dicts: List[TensorDict], pad_value: float = 0.0
) -> TensorDict:
"""Concatenate and pad tensors from multiple padded tensor dictionaries."""
if not tensor_dicts:
return TensorDict()
batch_sizes = [tuple(d.batch_size) for d in tensor_dicts]
new_batch_size = [sum(x[0] for x in batch_sizes), *batch_sizes[0][1:]]
# Find max sequence length across all dictionaries
assert all("attention_mask" in td for td in tensor_dicts)
max_length = max([x["attention_mask"].shape[1] for x in tensor_dicts])
result = {}
# Process each key
for key in tensor_dicts[0].keys():
tensors_to_concat = []
for tensor_dict in tensor_dicts:
tensor = tensor_dict[key]
# 1D tensors (e.g., rewards) are concatenated without padding
if len(tensor.shape) == 1:
tensors_to_concat.append(tensor)
continue
current_length = tensor.shape[1]
if current_length < max_length:
# Pad tensor to max_length
pad_width = max_length - current_length
if key == "attention_mask":
# Pad attention mask with 0s
padding = torch.zeros(
(tensor.shape[0], pad_width), dtype=tensor.dtype
)
else:
# Pad feature tensors with pad_value
padding = torch.full(
(tensor.shape[0], pad_width), pad_value, dtype=tensor.dtype
)
tensor = torch.cat([tensor, padding], dim=1)
tensors_to_concat.append(tensor)
result[key] = torch.cat(tensors_to_concat, dim=0)
return TensorDict(result, batch_size=new_batch_size)
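As a quick illustration of the helper above, two single-row TensorDicts with different sequence lengths can be combined like this (a sketch with toy tensors; only the shapes and padding behavior are being demonstrated):

import torch
from tensordict import TensorDict
from arealite.utils.padding import concat_padded_tensors

a = TensorDict(
    {"input_ids": torch.ones(1, 3, dtype=torch.long),
     "attention_mask": torch.ones(1, 3, dtype=torch.long),
     "rewards": torch.tensor([1.0])},
    batch_size=[1],
)
b = TensorDict(
    {"input_ids": torch.ones(1, 5, dtype=torch.long),
     "attention_mask": torch.ones(1, 5, dtype=torch.long),
     "rewards": torch.tensor([0.0])},
    batch_size=[1],
)
out = concat_padded_tensors([a, b])
# `a` is right-padded to length 5: input_ids with the default pad_value of 0
# and attention_mask with zeros, while the 1D rewards are concatenated as-is.
assert out["input_ids"].shape == (2, 5)
assert out["attention_mask"][0, 3:].sum() == 0
assert out.batch_size == torch.Size([2])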

View File

@ -1,3 +1,4 @@
import asyncio
import uuid
import torch
@ -7,6 +8,7 @@ from transformers import PreTrainedTokenizerFast
from arealite.api.cli_args import GenerationHyperparameters
from arealite.api.io_struct import LLMRequest
from arealite.api.workflow_api import RolloutWorkflow
from arealite.utils.padding import concat_padded_tensors
class RLVRWorkflow(RolloutWorkflow):
@ -24,33 +26,38 @@ class RLVRWorkflow(RolloutWorkflow):
text = self.tokenizer.apply_chat_template(
data["messages"], tokenize=False, add_generation_prompt=True
)
n_samples = self.gconfig.n_samples
req = LLMRequest(
rid=uuid.uuid4().hex,
text=text,
gconfig=self.gconfig,
gconfig=self.gconfig.new(n_samples=1),
)
resp = await engine.agenerate(req)
resps = await asyncio.gather(*[engine.agenerate(req) for _ in range(n_samples)])
seq = resp.input_tokens + resp.output_tokens
logprobs = [0] * resp.input_len + resp.output_logprobs
prompt_mask = [1] * resp.input_len + [0] * resp.output_len
versions = [-1] * resp.input_len + resp.output_versions
results = []
for resp in resps:
seq = resp.input_tokens + resp.output_tokens
logprobs = [0] * resp.input_len + resp.output_logprobs
prompt_mask = [1] * resp.input_len + [0] * resp.output_len
versions = [-1] * resp.input_len + resp.output_versions
reward = self.reward_fn(
prompt=req.text,
completions=resp.completions,
prompt_ids=resp.input_tokens,
completion_ids=resp.output_tokens,
**data,
)
res = dict(
# unsqueeze to add an additional batch dimension
input_ids=torch.tensor(seq).unsqueeze(0),
prompt_mask=torch.tensor(prompt_mask).unsqueeze(0),
logprobs=torch.tensor(logprobs).unsqueeze(0),
versions=torch.tensor(versions).unsqueeze(0),
# reward
rewards=torch.tensor([reward]),
)
reward = self.reward_fn(
prompt=req.text,
completions=resp.completions,
prompt_ids=resp.input_tokens,
completion_ids=resp.output_tokens,
**data,
)
res = dict(
# unsqueeze to add an additional batch dimension
input_ids=torch.tensor(seq).unsqueeze(0),
prompt_mask=torch.tensor(prompt_mask).unsqueeze(0),
logprobs=torch.tensor(logprobs).unsqueeze(0),
versions=torch.tensor(versions).unsqueeze(0),
attention_mask=torch.ones(len(seq)).unsqueeze(0),
# reward
rewards=torch.tensor([reward]),
)
results.append(TensorDict(res, batch_size=[1]))
return TensorDict(res, batch_size=[1])
return concat_padded_tensors(results)

View File

@ -53,6 +53,9 @@ dependencies = [
"hydra-core==1.4.0.dev1",
"packaging",
"tabulate",
"torchdata",
"gymnasium",
"tensordict",
# Monitoring and logging
"wandb",

View File

@ -69,4 +69,7 @@ word2number
Pebble
timeout-decorator
prettytable
swanlab[dashboard]
torchdata
gymnasium
tensordict