mirror of https://github.com/inclusionAI/AReaL
0724_merge8
This commit is contained in:
parent
6255ad5aa7
commit
b8549ac48a
|
@ -1,6 +1,7 @@
|
|||
import asyncio
|
||||
import uuid
|
||||
|
||||
import colorama
|
||||
import torch
|
||||
from tensordict import TensorDict
|
||||
from transformers import AutoProcessor, PreTrainedTokenizerFast
|
||||
|
@ -47,7 +48,13 @@ class VisionRLVRWorkflow(RLVRWorkflow):
|
|||
gconfig=self.gconfig.new(n_samples=1),
|
||||
)
|
||||
resps = await asyncio.gather(*[engine.agenerate(req) for _ in range(n_samples)])
|
||||
|
||||
|
||||
version = engine.get_version()
|
||||
prompt_strs = []
|
||||
completions_strs = []
|
||||
rewards = []
|
||||
seqlens = []
|
||||
|
||||
results = []
|
||||
for resp in resps:
|
||||
seq = resp.input_tokens + resp.output_tokens
|
||||
|
@ -55,14 +62,19 @@ class VisionRLVRWorkflow(RLVRWorkflow):
|
|||
loss_mask = [0] * resp.input_len + [1] * resp.output_len
|
||||
versions = [-1] * resp.input_len + resp.output_versions
|
||||
|
||||
prompt_str = self.tokenizer.decode(input_ids)
|
||||
completions_str = self.tokenizer.decode(resp.output_tokens)
|
||||
prompt_strs.append(prompt_str)
|
||||
completions_strs.append(completions_str)
|
||||
seqlens.append(len(seq))
|
||||
reward = self.reward_fn(
|
||||
prompt=self.tokenizer.decode(input_ids),
|
||||
completions=self.tokenizer.decode(resp.output_tokens),
|
||||
prompt=prompt_str,
|
||||
completions=completions_str,
|
||||
prompt_ids=resp.input_tokens,
|
||||
completion_ids=resp.output_tokens,
|
||||
**data,
|
||||
)
|
||||
|
||||
rewards.append(reward)
|
||||
res = dict(
|
||||
# unsqueeze to add an additional batch dimension
|
||||
input_ids=torch.tensor(seq).unsqueeze(0),
|
||||
|
@ -76,4 +88,32 @@ class VisionRLVRWorkflow(RLVRWorkflow):
|
|||
rewards=torch.tensor([reward]),
|
||||
)
|
||||
results.append(TensorDict(res, batch_size=[1]))
|
||||
if self.dump_dir is not None:
|
||||
os.makedirs(os.path.join(self.dump_dir, str(version)), exist_ok=True)
|
||||
# Get the unique identifier for this prompt
|
||||
qid = None
|
||||
for key in ["query_id", "id", "qid"]:
|
||||
qid = data.get(key, None)
|
||||
if qid is not None:
|
||||
break
|
||||
qid = qid or uuid.uuid4().hex
|
||||
|
||||
# Dump rollout to file
|
||||
with open(
|
||||
os.path.join(self.dump_dir, str(version), f"{qid}.txt"), "a"
|
||||
) as f:
|
||||
n_samples = self.gconfig.n_samples
|
||||
for i, (p, c, r, sl) in enumerate(
|
||||
zip(prompt_strs, completions_strs, rewards, seqlens)
|
||||
):
|
||||
info = "\n".join(
|
||||
[
|
||||
f"idx: {i + 1} / {n_samples}, seqlen: {sl}, reward is {r}.",
|
||||
f"prompt is \n{colorama.Fore.YELLOW + colorama.Style.DIM}{p}{colorama.Style.RESET_ALL}",
|
||||
f"sequence is: \n{colorama.Fore.YELLOW + colorama.Style.DIM}{c}{colorama.Style.RESET_ALL}",
|
||||
]
|
||||
)
|
||||
f.write(info + "\n")
|
||||
|
||||
return concat_padded_tensors(results)
|
||||
|
||||
|
|
|
@ -141,6 +141,9 @@ def main(args):
|
|||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
enable_thinking=False,
|
||||
dump_dir=os.path.join(
|
||||
StatsLogger.get_log_path(config.stats_logger), "generated"
|
||||
),
|
||||
)
|
||||
|
||||
# Run training.
|
||||
|
|
|
@ -3,8 +3,6 @@ import sys
|
|||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from datasets import Dataset, load_dataset
|
||||
from datasets.distributed import split_dataset_by_node
|
||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
|
||||
from arealite.api.cli_args import GRPOConfig, load_expr_config
|
||||
|
|
Loading…
Reference in New Issue