mirror of https://github.com/inclusionAI/AReaL
115 lines
4.3 KiB
Python
import asyncio
|
|
import os
|
|
import uuid
|
|
|
|
import colorama
|
|
import torch
|
|
from tensordict import TensorDict
|
|
from transformers import PreTrainedTokenizerFast
|
|
|
|
from arealite.api.cli_args import GenerationHyperparameters
|
|
from arealite.api.engine_api import InferenceEngine
|
|
from arealite.api.io_struct import LLMRequest
|
|
from arealite.api.workflow_api import RolloutWorkflow
|
|
from arealite.utils.data import concat_padded_tensors
|
|
|
|
|
|
class RLVRWorkflow(RolloutWorkflow):
    """Rollout workflow for RL with verifiable rewards (RLVR).

    For one dataset item, draws ``gconfig.n_samples`` completions from the
    inference engine, scores each completion with ``reward_fn``, and packs the
    per-sample tensors into a single padded batch. When ``dump_dir`` is set,
    the decoded rollouts are also appended to text files under
    ``dump_dir/<engine version>/``.
    """

    def __init__(
        self,
        reward_fn,
        gconfig: GenerationHyperparameters,
        tokenizer: PreTrainedTokenizerFast,
        enable_thinking: bool,
        dump_dir: str | None = None,
    ):
        """Store the workflow configuration.

        Args:
            reward_fn: Callable invoked once per completion as
                ``reward_fn(prompt=..., completions=..., prompt_ids=...,
                completion_ids=..., **data)``; its result must be accepted
                by ``float()``.
            gconfig: Generation hyperparameters; ``n_samples`` is the number
                of completions drawn per dataset item.
            tokenizer: Used to apply the chat template and to decode token ids.
            enable_thinking: Forwarded to ``tokenizer.apply_chat_template``.
            dump_dir: Optional directory for human-readable rollout dumps;
                created here if it does not exist.
        """
        self.reward_fn = reward_fn
        self.gconfig = gconfig
        self.tokenizer = tokenizer
        self.enable_thinking = enable_thinking
        self.dump_dir = dump_dir
        if self.dump_dir is not None:
            # exist_ok=True already tolerates an existing directory.
            os.makedirs(self.dump_dir, exist_ok=True)

    async def arun_episode(self, engine: InferenceEngine, data):
        """Generate, score, and pack ``gconfig.n_samples`` rollouts for one item.

        Args:
            engine: Inference engine; ``agenerate`` is awaited once per sample
                and ``get_version()`` is read after all samples finish.
            data: One dataset item. ``data["messages"]`` is passed to
                ``apply_chat_template``; every entry of ``data`` is also
                forwarded to ``reward_fn`` as a keyword argument.

        Returns:
            The per-sample ``TensorDict``s (batch size 1 each) combined with
            ``concat_padded_tensors``.
        """
        input_ids = self.tokenizer.apply_chat_template(
            data["messages"],
            tokenize=True,
            add_generation_prompt=True,
            enable_thinking=self.enable_thinking,
        )
        n_samples = self.gconfig.n_samples
        # Each request asks for one sample; n_samples copies are awaited
        # concurrently, all sharing this single request object (and rid).
        req = LLMRequest(
            rid=uuid.uuid4().hex,
            input_ids=input_ids,
            gconfig=self.gconfig.new(n_samples=1),
        )
        resps = await asyncio.gather(
            *[engine.agenerate(req) for _ in range(n_samples)]
        )

        version = engine.get_version()
        # The prompt is identical for every sample, so decode it only once.
        prompt_str = self.tokenizer.decode(input_ids)

        completions_strs = []
        rewards = []
        seqlens = []
        results = []
        for resp in resps:
            seq = resp.input_tokens + resp.output_tokens
            # Prompt positions get placeholder values: logprob 0.0,
            # loss mask 0 (not trained on) and version -1.
            logprobs = [0.0] * resp.input_len + resp.output_logprobs
            loss_mask = [0] * resp.input_len + [1] * resp.output_len
            versions = [-1] * resp.input_len + resp.output_versions

            completions_str = self.tokenizer.decode(resp.output_tokens)
            completions_strs.append(completions_str)
            seqlens.append(len(seq))

            reward = self.reward_fn(
                prompt=prompt_str,
                completions=completions_str,
                prompt_ids=resp.input_tokens,
                completion_ids=resp.output_tokens,
                **data,
            )
            rewards.append(reward)

            res = dict(
                # unsqueeze to add an additional batch dimension
                input_ids=torch.tensor(seq).unsqueeze(0),
                loss_mask=torch.tensor(loss_mask).unsqueeze(0),
                logprobs=torch.tensor(logprobs).unsqueeze(0),
                versions=torch.tensor(versions).unsqueeze(0),
                attention_mask=torch.ones(len(seq), dtype=torch.bool).unsqueeze(0),
                # reward
                rewards=torch.tensor([float(reward)]),
            )
            results.append(TensorDict(res, batch_size=[1]))

        if self.dump_dir is not None:
            self._dump_rollouts(
                version, data, prompt_str, completions_strs, rewards, seqlens
            )

        return concat_padded_tensors(results)

    def _dump_rollouts(
        self, version, data, prompt_str, completions_strs, rewards, seqlens
    ):
        """Append decoded rollouts for one item to ``dump_dir/<version>/<qid>.txt``."""
        version_dir = os.path.join(self.dump_dir, str(version))
        os.makedirs(version_dir, exist_ok=True)

        # Get the unique identifier for this prompt: first non-None of these keys.
        qid = None
        for key in ("query_id", "id", "qid"):
            qid = data.get(key, None)
            if qid is not None:
                break
        # Fall back to a random id only when none was found (or it is an empty
        # string); a falsy-but-valid id such as 0 must be kept.
        if qid is None or (isinstance(qid, str) and not qid):
            qid = uuid.uuid4().hex

        n_samples = self.gconfig.n_samples
        # colorama ANSI color codes are written into the file as-is.
        color = colorama.Fore.YELLOW + colorama.Style.DIM
        reset = colorama.Style.RESET_ALL
        path = os.path.join(version_dir, f"{qid}.txt")
        # Explicit UTF-8: decoded text may contain characters the platform's
        # default (locale) encoding cannot represent.
        with open(path, "a", encoding="utf-8") as f:
            for i, (c, r, sl) in enumerate(zip(completions_strs, rewards, seqlens)):
                info = "\n".join(
                    [
                        f"idx: {i + 1} / {n_samples}, seqlen: {sl}, reward is {r}.",
                        f"prompt is \n{color}{prompt_str}{reset}",
                        f"sequence is: \n{color}{c}{reset}",
                    ]
                )
                f.write(info + "\n")
|