mirror of https://github.com/inclusionAI/AReaL
dataset loading fixed
This commit is contained in: parent f8586e47c8, commit eb1e8a7592
@@ -8,6 +8,7 @@ import os
import random
import time
from contextlib import contextmanager
from enum import Enum

# NOTE: We don't use wildcard imports here because the type
# `Sequence` has a very similar name to `SequenceSample`.
@@ -43,6 +44,8 @@ from realhf.base.cluster import spec as cluster_spec

logger = logging.getLogger("api.data")

RL_TASKS = ["math", "code", "rlhf"]


def load_hf_tokenizer(
    model_name_or_path: str,
@@ -529,6 +532,7 @@ class SequenceSample:
            "rewards",
            "greedy_rewards",
            "base_scores",
            "task_ids",
        ]:
            return [[1] for _ in seqlens]
        elif key in [

@ -85,7 +85,7 @@ class PromptDataset(torch.utils.data.Dataset):
|
|||
)
|
||||
|
||||
|
||||
class MATHPromptDataset(torch.utils.data.Dataset):
|
||||
class MATHCodePromptDataset(torch.utils.data.Dataset):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -93,21 +93,9 @@ class MATHPromptDataset(torch.utils.data.Dataset):
|
|||
max_length: Optional[int] = None,
|
||||
dataset_path: Optional[str] = None,
|
||||
dataset_builder: Optional[Callable[[], List[Dict]]] = None,
|
||||
fill_to_max_length: bool = False,
|
||||
filter_threshold: float = 1e4,
|
||||
max_filter_percentage: float = 0.0,
|
||||
):
|
||||
"""A dataset with prompts. Usually used for PPO.
|
||||
|
||||
Args:
|
||||
util (api.data.DatasetUtility): .
|
||||
max_length (Optional[int], optional): The maximum length of each sequence in the batch.
|
||||
dataset_path (Optional[str], optional): Path to the dataset json/jsonl file.
|
||||
The json/jsonl file should be a list of dictionaries. Each element in the list should have
a key "prompt". Defaults to None.
dataset_builder (Optional[Callable[[], List[Dict]]], optional): Alternative to dataset_path.
A callable that returns a list of dictionaries. Defaults to None.
|
||||
"""
|
||||
self._util = util
|
||||
self.max_length = max_length
|
||||
|
||||
|
@ -115,6 +103,7 @@ class MATHPromptDataset(torch.utils.data.Dataset):
|
|||
|
||||
prompts_str = [x["prompt"] for x in data]
|
||||
self.ids = [x["query_id"] for x in data]
|
||||
self.tasks_ids = [data_api.RL_TASKS.index(x["task"]) for x in data]
|
||||
if "scores" in data[0]:
|
||||
self.base_scores = [np.mean(x["scores"]) for x in data]
|
||||
util.tokenizer.padding_side = "left"
|
||||
|
@ -127,148 +116,6 @@ class MATHPromptDataset(torch.utils.data.Dataset):
|
|||
return_attention_mask=False,
|
||||
)
|
||||
|
||||
if fill_to_max_length:
|
||||
for i in range(len(prompt_encodings["length"])):
|
||||
x = prompt_encodings["input_ids"][i]
|
||||
if max_length > len(x):
|
||||
# Fill with the final non-pad token to the left.
|
||||
prompt_encodings["input_ids"][i] = [x[-1]] * (
|
||||
max_length - len(x)
|
||||
) + x
|
||||
prompt_encodings["length"][i] = max_length
|
||||
|
||||
logger.info(f"{len(data)} samples, checking lengths (max_length={max_length})")
|
||||
indices = [
|
||||
i for i, x in enumerate(prompt_encodings["length"]) if x <= max_length
|
||||
]
|
||||
logger.info(f"{len(indices)} samples remain")
|
||||
|
||||
self.prompt_lengths = [int(prompt_encodings["length"][idx]) for idx in indices]
|
||||
self.prompts = [prompt_encodings["input_ids"][idx] for idx in indices]
|
||||
self.ids = [self.ids[idx] + f"@idx:{idx}-{util.dp_rank}" for idx in indices]
|
||||
if "scores" in data[0]:
|
||||
self.base_scores = [self.base_scores[idx] for idx in indices]
|
||||
|
||||
assert all(len(x) == l for x, l in zip(self.prompts, self.prompt_lengths))
|
||||
|
||||
logger.info(f"Number of prompts in the dataset: {len(self.prompts)}")
|
||||
|
||||
self.active_indices = list(range(len(self.prompts)))
|
||||
self.filter_threshold = filter_threshold
|
||||
self.max_filter_percentage = max_filter_percentage
|
||||
|
||||
@property
|
||||
def util(self):
|
||||
return self._util
|
||||
|
||||
def __len__(self):
|
||||
return len(self.active_indices)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
# print(self.base_scores)
|
||||
idx = self.active_indices[idx]
|
||||
if hasattr(self, "base_scores"):
|
||||
return data_api.SequenceSample.from_default(
|
||||
ids=[self.ids[idx]],
|
||||
seqlens=[self.prompt_lengths[idx]],
|
||||
data=dict(
|
||||
packed_prompts=torch.tensor(self.prompts[idx], dtype=torch.long),
|
||||
base_scores=torch.tensor(
|
||||
[self.base_scores[idx]], dtype=torch.float32
|
||||
),
|
||||
),
|
||||
)
|
||||
else:
|
||||
return data_api.SequenceSample.from_default(
|
||||
ids=[self.ids[idx]],
|
||||
seqlens=[self.prompt_lengths[idx]],
|
||||
data=dict(
|
||||
packed_prompts=torch.tensor(self.prompts[idx], dtype=torch.long)
|
||||
),
|
||||
)
|
||||
|
||||
def filter(self, eval_scores: Dict[Hashable, float]):
|
||||
# Get all data indices that have a higher score than the threshold.
|
||||
idx2scores_to_remove = {}
|
||||
for pop_idx, idx in enumerate(self.active_indices):
|
||||
data_id = self.ids[idx]
|
||||
if data_id not in eval_scores:
|
||||
continue
|
||||
if eval_scores[data_id] > self.filter_threshold:
|
||||
idx2scores_to_remove[pop_idx] = eval_scores[data_id]
|
||||
|
||||
# Control the number of samples to be removed according to max_filter_percentage.
|
||||
n = int(len(self.active_indices) * self.max_filter_percentage)
|
||||
indices_to_remove = sorted(
|
||||
idx2scores_to_remove.keys(),
|
||||
key=lambda x: idx2scores_to_remove[x],
|
||||
reverse=True,
|
||||
)[:n]
|
||||
|
||||
for pop_idx in sorted(indices_to_remove, reverse=True):
|
||||
self.active_indices.pop(pop_idx)
|
||||
logger.info(
|
||||
f"Math prompt dataset DP rank {self.util.dp_rank} filtered"
|
||||
f" {len(indices_to_remove)} samples, {len(self.active_indices)} samples remain. "
|
||||
f"Original dataset size: {len(self.prompts)}. "
|
||||
f"Filter threshold: {self.filter_threshold}. "
|
||||
f"Max filter percentage: {self.max_filter_percentage}. "
|
||||
f"Current number of eval scores: {len(eval_scores)}."
|
||||
)
|
||||
|
||||
|
||||
class CODEPromptDataset(torch.utils.data.Dataset):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
util: data_api.DatasetUtility,
|
||||
max_length: Optional[int] = None,
|
||||
dataset_path: Optional[str] = None,
|
||||
dataset_builder: Optional[Callable[[], List[Dict]]] = None,
|
||||
fill_to_max_length: bool = False,
|
||||
filter_threshold: float = 1e4,
|
||||
max_filter_percentage: float = 0.0,
|
||||
):
|
||||
"""A dataset with prompts. Usually used for PPO.
|
||||
|
||||
Args:
|
||||
util (api.data.DatasetUtility): .
|
||||
max_length (Optional[int], optional): The maximum length of each sequence in the batch.
|
||||
dataset_path (Optional[str], optional): Path to the dataset json/jsonl file.
|
||||
The json/jsonl file should be a list of dictionaries. Each element in the list should have
a key "prompt". Defaults to None.
dataset_builder (Optional[Callable[[], List[Dict]]], optional): Alternative to dataset_path.
A callable that returns a list of dictionaries. Defaults to None.
|
||||
"""
|
||||
self._util = util
|
||||
self.max_length = max_length
|
||||
|
||||
data = data_api.load_shuffle_split_dataset(util, dataset_path, dataset_builder)
|
||||
|
||||
prompts_str = [x["question"] for x in data]
|
||||
self.ids = [x["id"] for x in data]
|
||||
if "scores" in data[0]:
|
||||
self.base_scores = [np.mean(x["scores"]) for x in data]
|
||||
util.tokenizer.padding_side = "left"
|
||||
prompt_encodings = util.tokenizer(
|
||||
prompts_str,
|
||||
truncation=True,
|
||||
# max_length=max_length,
|
||||
padding=False,
|
||||
return_length=True,
|
||||
return_attention_mask=False,
|
||||
)
|
||||
|
||||
if fill_to_max_length:
|
||||
for i in range(len(prompt_encodings["length"])):
|
||||
x = prompt_encodings["input_ids"][i]
|
||||
if max_length > len(x):
|
||||
# Fill with the final non-pad token to the left.
|
||||
prompt_encodings["input_ids"][i] = [x[-1]] * (
|
||||
max_length - len(x)
|
||||
) + x
|
||||
prompt_encodings["length"][i] = max_length
|
||||
|
||||
logger.info(f"{len(data)} samples, checking lengths (max_length={max_length})")
|
||||
indices = [
|
||||
i for i, x in enumerate(prompt_encodings["length"]) if x <= max_length
|
||||
|
@ -300,32 +147,20 @@ class CODEPromptDataset(torch.utils.data.Dataset):
|
|||
|
||||
def __getitem__(self, idx):
|
||||
# print(self.base_scores)
|
||||
print(
|
||||
f"CODEPromptDataset idx{idx}, use idx: {self.active_indices[idx]}, size: {len(self.ids)}"
|
||||
)
|
||||
idx = self.active_indices[idx]
|
||||
|
||||
data = dict(
|
||||
task_ids=torch.tensor([self.tasks_ids[idx]], dtype=torch.long),
|
||||
packed_prompts=torch.tensor(self.prompts[idx], dtype=torch.long),
|
||||
)
|
||||
if hasattr(self, "base_scores"):
|
||||
return data_api.SequenceSample.from_default(
|
||||
ids=[self.ids[idx]],
|
||||
seqlens=[self.prompt_lengths[idx]],
|
||||
data=dict(
|
||||
packed_prompts=torch.tensor(self.prompts[idx], dtype=torch.long),
|
||||
base_scores=torch.tensor(
|
||||
[self.base_scores[idx]], dtype=torch.float32
|
||||
),
|
||||
),
|
||||
metadata=dict(random_id=[uuid.uuid4()]),
|
||||
)
|
||||
else:
|
||||
return data_api.SequenceSample.from_default(
|
||||
ids=[self.ids[idx]],
|
||||
seqlens=[self.prompt_lengths[idx]],
|
||||
data=dict(
|
||||
packed_prompts=torch.tensor(self.prompts[idx], dtype=torch.long)
|
||||
),
|
||||
metadata=dict(random_id=[uuid.uuid4()]),
|
||||
data["base_scores"] = torch.tensor(
|
||||
[self.base_scores[idx]], dtype=torch.float32
|
||||
)
|
||||
return data_api.SequenceSample.from_default(
|
||||
ids=[self.ids[idx]],
|
||||
seqlens=[self.prompt_lengths[idx]],
|
||||
data=data,
|
||||
)
|
||||
|
||||
def filter(self, eval_scores: Dict[Hashable, float]):
|
||||
# Get all data indices that have a higher score than the threshold.
|
||||
|
@ -348,7 +183,7 @@ class CODEPromptDataset(torch.utils.data.Dataset):
|
|||
for pop_idx in sorted(indices_to_remove, reverse=True):
|
||||
self.active_indices.pop(pop_idx)
|
||||
logger.info(
|
||||
f"Code prompt dataset DP rank {self.util.dp_rank} filtered"
|
||||
f"Math prompt dataset DP rank {self.util.dp_rank} filtered"
|
||||
f" {len(indices_to_remove)} samples, {len(self.active_indices)} samples remain. "
|
||||
f"Original dataset size: {len(self.prompts)}. "
|
||||
f"Filter threshold: {self.filter_threshold}. "
|
||||
|
@ -359,12 +194,11 @@ class CODEPromptDataset(torch.utils.data.Dataset):
|
|||
|
||||
if not __name__ == "__main__":
|
||||
data_api.register_dataset("prompt", PromptDataset)
|
||||
data_api.register_dataset("math_prompt", MATHPromptDataset)
|
||||
data_api.register_dataset("code_prompt", CODEPromptDataset)
|
||||
data_api.register_dataset("math_code_prompt", MATHCodePromptDataset)
|
||||
else:
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
dataset = MATHPromptDataset(
|
||||
dataset = MATHCodePromptDataset(
|
||||
data_api.DatasetUtility(
|
||||
seed=0,
|
||||
dp_rank=0,
|
||||
|
@ -374,23 +208,17 @@ else:
|
|||
),
|
||||
),
|
||||
max_length=512,
|
||||
dataset_path="/storage/openpsi/data/math/Qwen_RL_training_xss/0101/train_rl@0101_with_qwqsft-7b_score.jsonl",
|
||||
dataset_path="/storage/openpsi/users/bowei.fw/data/code_math.jsonl",
|
||||
)
|
||||
|
||||
dataset = CODEPromptDataset(
|
||||
data_api.DatasetUtility(
|
||||
seed=0,
|
||||
dp_rank=0,
|
||||
world_size=1,
|
||||
tokenizer=AutoTokenizer.from_pretrained(
|
||||
"/storage/openpsi/models/Qwen__Qwen2-1.5B-Instruct/"
|
||||
),
|
||||
),
|
||||
max_length=512,
|
||||
dataset_path="/storage/datasets/codeparrot-apps-test.jsonl",
|
||||
)
|
||||
data_generator = enumerate(dataset)
|
||||
dataset_batch_counter, cur_sample = next(data_generator)
|
||||
print(
|
||||
f"size: {len(dataset)}, index: {dataset_batch_counter}, cur_sample: {cur_sample}"
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
collate_fn=data_api.SequenceSample.gather,
|
||||
# NOTE: This is *NOT* the actual batch size for training.
|
||||
# It is just a proper size to load data to workers.
|
||||
batch_size=4,
|
||||
shuffle=True,
|
||||
)
|
||||
print(f"size: {len(dataset)}")
|
||||
for d in dataloader:
|
||||
print(d.ids)
|
||||
|
|
|
@@ -0,0 +1,101 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").

import dataclasses
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict

import torch

import realhf.api.core.model_api as model_api
import realhf.base.logging as logging
from realhf.api.core.config import ModelInterfaceAbstraction
from realhf.api.core.data_api import RL_TASKS, MicroBatchSpec, SequenceSample
from realhf.base.datapack import flat2d
from realhf.impl.model.nn.real_llm_api import ReaLModel

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class FusedThreadingForwardInterface(model_api.ModelInterface):

    def __init__(self, interfaces: Dict[str, ModelInterfaceAbstraction]):
        self.interfaces = {
            key: model_api.make_interface(interface)
            for key, interface in interfaces.items()
        }

    def run_interface(
        self,
        interface_name: str,
        model,
        data,
        mb_spec,
    ) -> SequenceSample | None:
        tik = time.perf_counter()
        res = self.interfaces[interface_name].inference(model, data, mb_spec)
        t = time.perf_counter() - tik
        logger.info(f"Interface {interface_name} cost {t} s")
        return res

    def inference(
        self,
        model: model_api.Model,
        data: SequenceSample,
        mb_spec: MicroBatchSpec,
    ) -> SequenceSample | None:
        with ThreadPoolExecutor(max_workers=len(self.interfaces)) as executor:
            tasks = []
            for interface_name in self.interfaces:
                task = executor.submit(
                    self.run_interface, interface_name, model, data, mb_spec
                )
                tasks.append(task)

            final_result = None
            for task in as_completed(tasks):
                res = task.result()
                if res is None:
                    continue
                if final_result is None:
                    final_result = res
                else:
                    final_result.update_(res)

        return final_result

    # Mock methods for profiling only.
    def _mock_inference(
        self,
        model: model_api.Model,
        data: SequenceSample,
    ) -> SequenceSample:
        prompt_lens = flat2d(data.seqlens["packed_prompts"])
        seqlens = [x + 1024 for x in prompt_lens]
        module = model.module
        if not isinstance(module, ReaLModel):
            module = module.module
        mconfig = module.config
        packed_input_ids = torch.randint(
            0,
            mconfig.vocab_size,
            (sum(seqlens),),
            dtype=torch.long,
            device=model.device,
        )
        n_tasks = len(RL_TASKS)
        task_ids = torch.randint(
            0, n_tasks, (data.bs,), dtype=torch.long, device=model.device
        )

        return SequenceSample.from_default(
            seqlens=seqlens,
            ids=data.ids,
            data=dict(packed_input_ids=packed_input_ids, task_ids=task_ids),
        )


model_api.register_interface("fused-threading", FusedThreadingForwardInterface)

@@ -1,40 +1,41 @@
# Copyright 2025 Ant Group Inc.

import collections
import copy
import ast
import asyncio
import dataclasses
import html
import itertools
import json
import os
import random
import re
import time
import xml.etree.ElementTree as ET
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List, Optional
from typing import Dict, List, Tuple

import colorama
import numpy as np
import requests
import torch
import torch.distributed as dist
import tqdm
import transformers

import realhf.api.core.model_api as model_api
import realhf.base.logging as logging
from functioncall.code.verify import code_verify
from functioncall.math.verify import math_verify
from realhf.api.core.data_api import SequenceSample, load_hf_tokenizer
from realhf.api.core.data_api import (
    RL_TASKS,
    MicroBatchSpec,
    SequenceSample,
    load_hf_tokenizer,
)
from realhf.base import constants
from realhf.base.constants import data_parallel_group, data_parallel_world_size
from realhf.base.datapack import flat2d
from realhf.impl.model.interface.math_parser import parse_lines_in_parallel
from realhf.impl.model.nn.real_llm_api import ReaLModel
from realhf.impl.model.interface.math_parser import (
    parse_lines_in_parallel as math_verify_local,
)

logger = logging.getLogger("Packed Reward Modeling Interface", "benchmark")

ENABLE_FUNCTION_CALL = True if os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", "") else False
math_verify_call = math_verify if ENABLE_FUNCTION_CALL else parse_lines_in_parallel
math_verify_call = math_verify if ENABLE_FUNCTION_CALL else math_verify_local


class VerifierException(Exception):
@ -53,7 +54,7 @@ def extract_python_code(text, min_length=20, strict_syntax=True):
|
|||
# verify code syntax
|
||||
if strict_syntax:
|
||||
try:
|
||||
parse(clean_block, mode="exec")
|
||||
ast.parse(clean_block, mode="exec")
|
||||
except (SyntaxError, IndentationError):
|
||||
continue
|
||||
|
||||
|
@ -118,40 +119,28 @@ def check_with_elementtree(text):
|
|||
return False, f"Error: XML格式错误, {str(e)}"
|
||||
|
||||
|
||||
def reward_caculate(task, _answers, query_id_strs):
|
||||
def dispatch_reward_calculation(task, answers, query_id_strs) -> List:
|
||||
assert len(answers) == len(query_id_strs)
|
||||
format_rewards = []
|
||||
if task == "math":
|
||||
format_rewards = math_verify_call(_answers, query_id_strs)
|
||||
else:
|
||||
codes = [extract_python_code(_answer) for _answer in _answers]
|
||||
format_rewards = math_verify_call(answers, query_id_strs)
|
||||
elif task == "code":
|
||||
codes = [extract_python_code(_answer) for _answer in answers]
|
||||
format_rewards = code_verify(codes, query_id_strs)
|
||||
logger.info(
|
||||
f"reward_caculate, task: {task}, size: {len(query_id_strs)}, query_id_0: {query_id_strs[0]}"
|
||||
)
|
||||
assert len(format_rewards) == len(answers), task
|
||||
return format_rewards
|
||||
|
||||
|
||||
def retokenize(
|
||||
def retokenize_and_verify(
|
||||
task,
|
||||
tokenizer,
|
||||
packed_input_ids,
|
||||
input_cu_seqlens,
|
||||
prompts,
|
||||
prompt_cu_seqlens,
|
||||
query_ids,
|
||||
prompt_ids: List[List[int]],
|
||||
seq_ids: List[List[int]],
|
||||
query_ids: List[str],
|
||||
check_xml_format=False,
|
||||
do_eval=False,
|
||||
):
|
||||
input_ids = [
|
||||
packed_input_ids[start:end]
|
||||
for start, end in zip(input_cu_seqlens[:-1], input_cu_seqlens[1:])
|
||||
]
|
||||
prompt_ids = [
|
||||
prompts[start:end]
|
||||
for start, end in zip(prompt_cu_seqlens[:-1], prompt_cu_seqlens[1:])
|
||||
]
|
||||
seq_strs = tokenizer.batch_decode(
|
||||
input_ids, clean_up_tokenization_spaces=False, skip_special_tokens=True
|
||||
seq_ids, clean_up_tokenization_spaces=False, skip_special_tokens=True
|
||||
)
|
||||
prompt_strs = tokenizer.batch_decode(
|
||||
prompt_ids, clean_up_tokenization_spaces=False, skip_special_tokens=True
|
||||
|
@ -159,250 +148,131 @@ def retokenize(
|
|||
# query_id_strs = query_ids
|
||||
query_id_strs = [query_id.split("@")[0] for query_id in query_ids]
|
||||
|
||||
logger.info(
|
||||
f"retokenize, query_id_strs:{query_id_strs}, seq_strs:{seq_strs}, prompt_strs:{prompt_strs}"
|
||||
)
|
||||
answers = [
|
||||
seq_str.split(prompt_str)[1]
|
||||
for seq_str, prompt_str in zip(seq_strs, prompt_strs)
|
||||
]
|
||||
|
||||
format_rewards = []
|
||||
queryid_to_results = collections.defaultdict(list)
|
||||
# 8 processes on each node, with 10 subprocesses each
|
||||
if do_eval == True:
|
||||
_answers = [
|
||||
seq_str.split(prompt_str)[1]
|
||||
for seq_str, prompt_str in zip(seq_strs, prompt_strs)
|
||||
]
|
||||
format_rewards = dispatch_reward_calculation(task, answers, query_id_strs)
|
||||
|
||||
format_rewards = math_verify_call(_answers, query_id_strs)
|
||||
if check_xml_format:
|
||||
with ThreadPoolExecutor(max_workers=22) as executor:
|
||||
futures = [
|
||||
executor.submit(check_with_elementtree, answer_str)
|
||||
for answer_str in _answers
|
||||
]
|
||||
# xml_rewards = []
|
||||
for idx, future in enumerate(futures):
|
||||
xml_reward, _ = future.result()
|
||||
# xml_rewards.append(xml_reward)
|
||||
if xml_reward == 1 and format_rewards[idx] == 0:
|
||||
format_rewards[idx] = -0.8
|
||||
elif xml_reward == 0 and format_rewards[idx] == 0:
|
||||
format_rewards[idx] = -1
|
||||
if check_xml_format:
|
||||
for idx, answer in enumerate(answers):
|
||||
xml_reward, _ = check_with_elementtree(answer)
|
||||
if xml_reward == 1 and format_rewards[idx] == 0:
|
||||
format_rewards[idx] = -0.8
|
||||
elif xml_reward == 0 and format_rewards[idx] == 0:
|
||||
format_rewards[idx] = -1
|
||||
|
||||
for query_id_str, format_reward in zip(query_id_strs, format_rewards):
|
||||
if query_id_str not in queryid_to_results:
|
||||
queryid_to_results[query_id_str] = []
|
||||
queryid_to_results[query_id_str].append(format_reward)
|
||||
else:
|
||||
for query_id_str in query_id_strs:
|
||||
if query_id_str not in queryid_to_results:
|
||||
queryid_to_results[query_id_str] = []
|
||||
queryid_to_results[query_id_str].append(0)
|
||||
format_rewards.append(0)
|
||||
|
||||
return format_rewards, prompt_strs, prompt_ids, seq_strs, queryid_to_results
|
||||
return format_rewards, prompt_strs, seq_strs
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PackedRewardInterface(model_api.ModelInterface):
|
||||
enable_save: bool = False
|
||||
class MultiTaskCPURewardInterface(model_api.ModelInterface):
|
||||
tokenizer_path: str = "/storage/models/Qwen__Qwen2.5-1.5B"
|
||||
output_scaling: float = 1.0
|
||||
rm_output_scaling: float = 1.0
|
||||
rm_output_bias: float = 0.0
|
||||
output_scaling: float = 1.0
|
||||
output_bias: float = 0.0
|
||||
loss_fun = torch.nn.CrossEntropyLoss(reduction="none")
|
||||
max_sync_length: int = 2048
|
||||
rw_type: str = "sparse"
|
||||
task: str = "math" # math or countdown
|
||||
check_xml_format: bool = False
|
||||
post_process: str = "sigmoid"
|
||||
group_size: int = 1
|
||||
check_verifier_status: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
self.tokenizer = load_hf_tokenizer(self.tokenizer_path)
|
||||
logger.info(f"rm_output_scaling: {self.rm_output_scaling}")
|
||||
logger.info(f"rm_output_bias: {self.rm_output_bias}")
|
||||
logger.info(f"output_scaling: {self.output_scaling}")
|
||||
logger.info(f"output_bias: {self.output_bias}")
|
||||
logger.info(f"max_sync_length: {self.max_sync_length}")
|
||||
logger.info(f"rw_type: {self.rw_type}")
|
||||
logger.info(f"post_process: {self.post_process}")
|
||||
if constants.parallelism_rank() == 0:
|
||||
logger.info(f"output_scaling: {self.output_scaling}")
|
||||
logger.info(f"output_bias: {self.output_bias}")
|
||||
logger.info(f"rw_type: {self.rw_type}")
|
||||
if (
|
||||
constants.data_parallel_world_size()
|
||||
< constants.parallelism_group_size()
|
||||
):
|
||||
logger.warning(
|
||||
"There's no reason to use tensor and pipeline parallelism for CPU reward."
|
||||
)
|
||||
|
||||
def inference(
|
||||
def _dispatch_tasks(self, data: SequenceSample) -> Tuple[Dict, Dict]:
|
||||
xs = data.unpack()
|
||||
dispatched = {}
|
||||
dispatched_indices = {}
|
||||
for task_idx, task_name in enumerate(RL_TASKS):
|
||||
indices = (data.data["task_ids"] == task_idx).numpy().tolist()
|
||||
if any(indices):
|
||||
dispatched[task_name] = SequenceSample.gather([xs[i] for i in indices])
|
||||
dispatched_indices[task_name] = indices
|
||||
|
||||
return dispatched, dispatched_indices
|
||||
|
||||
def _gather_tasks(
|
||||
self, results: Dict, dispatched_indices: Dict, bs: int
|
||||
) -> SequenceSample:
|
||||
xs = [None for _ in range(bs)]
|
||||
for task_name, indices in dispatched_indices.items():
|
||||
xxs = results[task_name].unpack()
|
||||
for i, xx in zip(indices, xxs):
|
||||
xs[i] = xx
|
||||
assert all(xs)
|
||||
return SequenceSample.gather(xs)
|
||||
|
||||
def calculate_task_reward(
|
||||
self,
|
||||
model: model_api.Model,
|
||||
data_: SequenceSample,
|
||||
mb_spec,
|
||||
) -> SequenceSample:
|
||||
|
||||
packed_input_ids: torch.Tensor = data_.data["packed_input_ids"].squeeze()
|
||||
|
||||
input_seqlens = torch.tensor(data_.seqlens["packed_input_ids"]).view(-1)
|
||||
input_cu_seqlens = torch.nn.functional.pad(
|
||||
input_seqlens.cumsum(0), (1, 0)
|
||||
).int()
|
||||
|
||||
packed_prompts = data_.data["packed_prompts"]
|
||||
prompts = []
|
||||
prompt_seqlens = []
|
||||
data: SequenceSample,
|
||||
mb_spec: MicroBatchSpec,
|
||||
task_type: str,
|
||||
):
|
||||
# mb_spec is disrespected here
|
||||
packed_input_ids: torch.Tensor = data.data["packed_input_ids"]
|
||||
input_seqlens = flat2d(data.seqlens["packed_input_ids"])
|
||||
seq_ids = []
|
||||
offset = 0
|
||||
for x in data_.seqlens["packed_prompts"]:
|
||||
prompts += [packed_prompts[offset : offset + x[0]]] * self.group_size
|
||||
offset += x[0]
|
||||
prompt_seqlens.extend(x * self.group_size)
|
||||
|
||||
assert offset == sum(x[0] for x in data_.seqlens["packed_prompts"])
|
||||
# non_packed_prompts = copy.deepcopy(prompts)
|
||||
prompts = torch.cat(prompts)
|
||||
prompt_seqlens = torch.tensor(prompt_seqlens).view(-1)
|
||||
prompt_cu_seqlens = torch.nn.functional.pad(
|
||||
prompt_seqlens.cumsum(0), (1, 0)
|
||||
).int()
|
||||
|
||||
query_ids = [data_id for data_id in data_.ids for _ in range(self.group_size)]
|
||||
|
||||
format_rewards, prompt_strs, prompt_ids, seq_strs, queryid_to_results = (
|
||||
retokenize(
|
||||
self.task,
|
||||
self.tokenizer,
|
||||
packed_input_ids,
|
||||
input_cu_seqlens,
|
||||
prompts,
|
||||
prompt_cu_seqlens,
|
||||
query_ids,
|
||||
check_xml_format=self.check_xml_format,
|
||||
do_eval=constants.is_last_pipe_stage(),
|
||||
for slen in input_seqlens:
|
||||
seq_ids.append(
|
||||
packed_input_ids[offset : offset + slen].cpu().numpy().tolist()
|
||||
)
|
||||
offset += slen
|
||||
assert offset == packed_input_ids.shape[0], (offset, packed_input_ids.shape)
|
||||
prompt_input_ids = data.data["packed_prompts"]
|
||||
prompt_len = flat2d(data.seqlens["packed_prompts"])
|
||||
prompt_ids = []
|
||||
offset = 0
|
||||
for slen in prompt_len:
|
||||
p = prompt_input_ids[offset : offset + slen].cpu().numpy().tolist()
|
||||
prompt_ids += [p] * self.group_size
|
||||
offset += slen
|
||||
format_rewards, prompt_strs, seq_strs = retokenize_and_verify(
|
||||
task_type,
|
||||
self.tokenizer,
|
||||
prompt_ids=prompt_ids,
|
||||
seq_ids=seq_ids,
|
||||
query_ids=[
|
||||
str(data_id) for data_id in data.ids for _ in range(self.group_size)
|
||||
],
|
||||
check_xml_format=self.check_xml_format,
|
||||
)
|
||||
|
||||
assert self.rw_type == "sparse"
|
||||
dense_scores = torch.zeros_like(packed_input_ids).float()
|
||||
scores = torch.FloatTensor(format_rewards).to(packed_input_ids.device)
|
||||
scores[scores == 0] = -1
|
||||
|
||||
if len(scores) == 0:
|
||||
return None
|
||||
|
||||
assert dense_scores.shape == packed_input_ids.shape
|
||||
scores = (
|
||||
scores.to(packed_input_ids.device) - self.output_bias
|
||||
) * self.output_scaling
|
||||
|
||||
logger.info(f"Math reward logging info @v{model.version.global_step}")
|
||||
logger.info(
|
||||
f"before: Format success rate: {torch.FloatTensor(format_rewards).mean().item()}"
|
||||
)
|
||||
logger.info(f"number of samples: {len(scores)}, {scores.shape}")
|
||||
if constants.is_last_pipe_stage():
|
||||
self.log_rewards_to_file(task_type, model, prompt_strs, seq_strs, scores)
|
||||
|
||||
gen_file_path = os.path.join(
|
||||
constants.LOG_ROOT,
|
||||
constants.experiment_name(),
|
||||
constants.trial_name(),
|
||||
"generated",
|
||||
f"v{model.version.global_step}r{dist.get_rank()}.txt",
|
||||
)
|
||||
|
||||
os.makedirs(os.path.dirname(gen_file_path), exist_ok=True)
|
||||
logger.info(
|
||||
f"Generated samples and rewards will be dumped to: {gen_file_path}"
|
||||
)
|
||||
with open(gen_file_path, "w") as _f:
|
||||
for idx, (score, prompt_str, seq_str) in enumerate(
|
||||
zip(scores, prompt_strs, seq_strs)
|
||||
):
|
||||
info = "\n".join(
|
||||
[
|
||||
f"idx: {idx} / {len(scores)}",
|
||||
f"reward is {score.item()}, prompt is {colorama.Fore.YELLOW + colorama.Style.DIM}{prompt_str}{colorama.Style.RESET_ALL}",
|
||||
f"sequence is: {colorama.Fore.YELLOW + colorama.Style.DIM}{seq_str.split(prompt_str)[1]}{colorama.Style.RESET_ALL}.",
|
||||
]
|
||||
)
|
||||
_f.write(info + "\n")
|
||||
|
||||
gen_file_path = os.path.join(
|
||||
constants.LOG_ROOT,
|
||||
constants.experiment_name(),
|
||||
constants.trial_name(),
|
||||
"generated_jsonl",
|
||||
f"v{model.version.global_step}r{dist.get_rank()}.jsonl",
|
||||
)
|
||||
os.makedirs(os.path.dirname(gen_file_path), exist_ok=True)
|
||||
logger.info(
|
||||
f"Generated samples and rewards will be dumped to: {gen_file_path}"
|
||||
)
|
||||
with open(gen_file_path, "w") as _f:
|
||||
for idx, (score, prompt_str, seq_str) in enumerate(
|
||||
zip(scores, prompt_strs, seq_strs)
|
||||
):
|
||||
_f.write(
|
||||
json.dumps(
|
||||
{
|
||||
"prompt": prompt_str,
|
||||
"generated": seq_str.split(prompt_str)[1],
|
||||
"reward": score.item(),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Format success rate: {torch.FloatTensor(format_rewards).mean().item()}"
|
||||
)
|
||||
|
||||
pass_at_k = np.mean(
|
||||
[sum([xx == 1 for xx in x]) > 0 for x in queryid_to_results.values()]
|
||||
)
|
||||
avg_num_samples = np.mean([len(x) for x in queryid_to_results.values()])
|
||||
logger.info(f"pass@k: {pass_at_k}, num_samples: {avg_num_samples}")
|
||||
|
||||
logger.info(f"number of samples: {len(scores)}, {scores.shape}")
|
||||
|
||||
logger.info(f"reward: {sum(scores) / len(scores)}")
|
||||
|
||||
train_pass_monitor_file_path = os.path.join(
|
||||
constants.LOG_ROOT,
|
||||
constants.experiment_name(),
|
||||
constants.trial_name(),
|
||||
"training_monitor",
|
||||
f"v{model.version.global_step}r{dist.get_rank()}.jsonl",
|
||||
)
|
||||
os.makedirs(os.path.dirname(train_pass_monitor_file_path), exist_ok=True)
|
||||
logger.info(
|
||||
f"pass monitor result will be dumped to: {train_pass_monitor_file_path}"
|
||||
)
|
||||
with open(train_pass_monitor_file_path, "w") as monitor_file:
|
||||
for key, value in queryid_to_results.items():
|
||||
pass1 = sum(value) / len(value)
|
||||
pass8 = int(sum(value) > 0)
|
||||
monitor_file.write(
|
||||
json.dumps(
|
||||
{"query_id": key, "pass1": pass1, "pass8": pass8},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
model.inc_version()
|
||||
|
||||
if scores.dtype != torch.float32:
|
||||
scores = scores.to(torch.float32)
|
||||
if dense_scores.dtype != torch.float32:
|
||||
dense_scores = dense_scores.to(torch.float32)
|
||||
# NOTE: a place holder
|
||||
dense_scores = packed_input_ids.new_zeros(dtype=torch.float32)
|
||||
|
||||
res = SequenceSample(
|
||||
keys=["rewards", "dense_rewards"],
|
||||
trailing_shapes=dict(rewards=(), dense_rewards=()),
|
||||
dtypes=dict(rewards=torch.float32, dense_rewards=torch.float32),
|
||||
ids=data_.ids,
|
||||
ids=data.ids,
|
||||
seqlens=dict(
|
||||
rewards=[
|
||||
torch.tensor([1 for _ in range(len(x))], dtype=torch.int32)
|
||||
for x in data_.seqlens["packed_input_ids"]
|
||||
for x in data.seqlens["packed_input_ids"]
|
||||
],
|
||||
dense_rewards=data_.seqlens["packed_input_ids"],
|
||||
dense_rewards=data.seqlens["packed_input_ids"],
|
||||
),
|
||||
data=dict(rewards=scores, dense_rewards=dense_scores),
|
||||
)
|
||||
|
@ -410,13 +280,13 @@ class PackedRewardInterface(model_api.ModelInterface):
|
|||
# record rewards for each piece of data
|
||||
avg_scores = []
|
||||
offset = 0
|
||||
for i in range(data_.bs):
|
||||
for i in range(data.bs):
|
||||
score_lis = scores[
|
||||
offset : offset + len(data_.seqlens["packed_input_ids"][i])
|
||||
offset : offset + len(data.seqlens["packed_input_ids"][i])
|
||||
]
|
||||
avg_scores.append(score_lis.mean().item())
|
||||
offset += len(data_.seqlens["packed_input_ids"][i])
|
||||
assert offset == sum(len(x) for x in data_.seqlens["packed_input_ids"])
|
||||
offset += len(data.seqlens["packed_input_ids"][i])
|
||||
assert offset == sum(len(x) for x in data.seqlens["packed_input_ids"])
|
||||
|
||||
res.metadata["scores"] = avg_scores
|
||||
|
||||
|
@ -425,20 +295,120 @@ class PackedRewardInterface(model_api.ModelInterface):
|
|||
np.mean(avg_scores), device=constants.current_device()
|
||||
)
|
||||
dist.all_reduce(
|
||||
avg_score, op=dist.ReduceOp.SUM, group=constants.parallelism_group()
|
||||
avg_score, op=dist.ReduceOp.SUM, group=constants.data_parallel_group()
|
||||
)
|
||||
avg_score /= constants.parallelism_group_size()
|
||||
avg_score /= constants.data_parallel_group()
|
||||
avg_score = avg_score.item()
|
||||
minimal_score = (-1 - self.output_bias) * self.rm_output_scaling
|
||||
minimal_score = (-1 - self.output_bias) * self.output_scaling
|
||||
|
||||
if avg_score <= minimal_score or np.isclose(avg_score, minimal_score):
|
||||
raise VerifierException(
|
||||
"All rewards are at minimal value. Probably there are something wrong with the verifier!"
|
||||
)
|
||||
|
||||
if not constants.is_last_pipe_stage():
|
||||
return None
|
||||
return res
|
||||
|
||||
def log_rewards_to_file(
|
||||
self, task_type: str, model: model_api.Model, prompt_strs, seq_strs, scores
|
||||
):
|
||||
gen_file_path = os.path.join(
|
||||
constants.LOG_ROOT,
|
||||
constants.experiment_name(),
|
||||
constants.trial_name(),
|
||||
"generated",
|
||||
task_type,
|
||||
f"v{model.version.global_step}r{dist.get_rank()}.txt",
|
||||
)
|
||||
|
||||
model_api.register_interface("reward", PackedRewardInterface)
|
||||
os.makedirs(os.path.dirname(gen_file_path), exist_ok=True)
|
||||
with open(gen_file_path, "w") as _f:
|
||||
for idx, (score, prompt_str, seq_str) in enumerate(
|
||||
zip(scores, prompt_strs, seq_strs)
|
||||
):
|
||||
info = "\n".join(
|
||||
[
|
||||
f"idx: {idx} / {len(scores)}",
|
||||
f"reward is {score.item()}, prompt is {colorama.Fore.YELLOW + colorama.Style.DIM}{prompt_str}{colorama.Style.RESET_ALL}",
|
||||
f"sequence is: {colorama.Fore.YELLOW + colorama.Style.DIM}{seq_str.split(prompt_str)[1]}{colorama.Style.RESET_ALL}.",
|
||||
]
|
||||
)
|
||||
_f.write(info + "\n")
|
||||
|
||||
gen_file_path = os.path.join(
|
||||
constants.LOG_ROOT,
|
||||
constants.experiment_name(),
|
||||
constants.trial_name(),
|
||||
"generated_jsonl",
|
||||
task_type,
|
||||
f"v{model.version.global_step}r{dist.get_rank()}.jsonl",
|
||||
)
|
||||
os.makedirs(os.path.dirname(gen_file_path), exist_ok=True)
|
||||
with open(gen_file_path, "w") as _f:
|
||||
for idx, (score, prompt_str, seq_str) in enumerate(
|
||||
zip(scores, prompt_strs, seq_strs)
|
||||
):
|
||||
_f.write(
|
||||
json.dumps(
|
||||
{
|
||||
"prompt": prompt_str,
|
||||
"generated": seq_str.split(prompt_str)[1],
|
||||
"reward": score.item(),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
logger.info(f"number of samples: {len(scores)}, {scores.shape}")
|
||||
logger.info(f"reward: {sum(scores) / len(scores)}")
|
||||
|
||||
def inference(
|
||||
self,
|
||||
model: model_api.Model,
|
||||
data: SequenceSample,
|
||||
mb_spec: MicroBatchSpec,
|
||||
) -> SequenceSample | None:
|
||||
task_data, dispatch_indices = self._dispatch_tasks(data)
|
||||
|
||||
assert self.rw_type == "sparse"
|
||||
|
||||
def _task_func(func, task_type: str):
|
||||
def _wrapped_func(*args, **kwargs):
|
||||
start_time = time.perf_counter()
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
raise asyncio.CancelledError(f"{task_type} task failed: {e}") from e
|
||||
finally:
|
||||
duration = time.perf_counter() - start_time
|
||||
logger.info(f"[{task_type}] time cost: {duration:.4f}s")
|
||||
return result
|
||||
|
||||
return _wrapped_func
|
||||
|
||||
async def _run_tasks():
|
||||
tasks = []
|
||||
for task_type, d in task_data.items():
|
||||
task_func = _task_func(self.calculate_task_reward, task_type)
|
||||
task_args = (model, d, mb_spec, task_type)
|
||||
task = asyncio.create_task(asyncio.to_thread(task_func, *task_args))
|
||||
tasks.append((task_type, task))
|
||||
|
||||
task_results = {}
|
||||
results = await asyncio.gather(*tasks)
|
||||
for task_type, result in results:
|
||||
task_results[task_type] = result
|
||||
|
||||
return task_results
|
||||
|
||||
if constants.is_dp_head():
|
||||
task_results = asyncio.run(_run_tasks())
|
||||
final_result = self._gather_tasks(task_results, dispatch_indices, data.bs)
|
||||
else:
|
||||
final_result = None
|
||||
|
||||
model.inc_version()
|
||||
|
||||
return final_result
|
||||
|
||||
|
||||
model_api.register_interface("reward", MultiTaskCPURewardInterface)
|
||||
|
|
|
@ -1,266 +0,0 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
# Copyright 2024 Wei Fu & Zhiyu Mei
|
||||
# Licensed under the Apache License, Version 2.0 (the "License").
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import functools
|
||||
import itertools
|
||||
import time
|
||||
from typing import Dict, Literal, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
import realhf.api.core.model_api as model_api
|
||||
import realhf.base.logging as logging
|
||||
import realhf.impl.model.utils.ppo_functional as ppo_functional
|
||||
from realhf.api.core.data_api import MicroBatchSpec, SequenceSample
|
||||
from realhf.base.datapack import flat2d
|
||||
from realhf.impl.model.interface.rw_interface import PackedRewardInterface
|
||||
from realhf.impl.model.nn.real_llm_api import ReaLModel
|
||||
from realhf.impl.model.utils.functional import (
|
||||
gather_packed_shifted_log_probs,
|
||||
masked_normalization,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("RefRwInterface")
|
||||
|
||||
TASK_TYPE_REF: Literal["ref"] = "ref"
|
||||
TASK_TYPE_RW_MATH: Literal["rw_math"] = "rw_math"
|
||||
TASK_TYPE_RW_CODE: Literal["rw_code"] = "rw_code"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class RefRwInterface(model_api.ModelInterface):
|
||||
n_minibatches: int = 4
|
||||
|
||||
# Use dict here to allow argument passing through commandline.
|
||||
generation_config: Dict = dataclasses.field(default_factory=dict)
|
||||
|
||||
kl_ctl: float = 0.1
|
||||
|
||||
adv_norm: bool = True
|
||||
discount: float = 1.0
|
||||
gae_lambda: float = 1.0
|
||||
|
||||
eps_clip: float = 0.2
|
||||
value_eps_clip: float = 0.2
|
||||
max_reward_clip: float = 5.0
|
||||
|
||||
disable_value: bool = False
|
||||
|
||||
early_stop_kl: Optional[float] = None # e.g. 0.1
|
||||
early_stop_imp_ratio: Optional[float] = None # e.g., 10.0
|
||||
|
||||
adaptive_kl_ctl: bool = False
|
||||
adaptive_kl_target: Optional[float] = 6
|
||||
adaptive_kl_horizon: Optional[float] = 10000
|
||||
|
||||
enable_save: bool = True
|
||||
|
||||
value_norm: bool = False
|
||||
value_norm_type: str = dataclasses.field(
|
||||
metadata={"choices": ["exp", "ma"]}, default="exp"
|
||||
)
|
||||
value_norm_beta: float = 0.99995
|
||||
value_norm_eps: float = 1e-5
|
||||
|
||||
group_size: int = 1
|
||||
generation_size: Optional[int] = None
|
||||
mask_no_eos_with_zero: bool = False
|
||||
group_adv_norm: bool = False
|
||||
mask_too_long: bool = False
|
||||
use_dense_reward: bool = False
|
||||
reward_delta: bool = True
|
||||
token_normalize_scope: Literal["global", "dp"] = "global"
|
||||
rew_inf_args: Dict = dataclasses.field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.adaptive_kl_ctl:
|
||||
assert self.adaptive_kl_target is not None
|
||||
assert self.adaptive_kl_horizon is not None
|
||||
self.kl_adapter = ppo_functional.AdaptiveKLController(
|
||||
self.kl_ctl, self.adaptive_kl_target, self.adaptive_kl_horizon
|
||||
)
|
||||
else:
|
||||
self.kl_adapter = ppo_functional.FixedKLController(self.kl_ctl)
|
||||
if self.value_norm:
|
||||
from realhf.impl.model.modules import (
|
||||
ExponentialRunningMeanStd,
|
||||
MovingAverageRunningMeanStd,
|
||||
)
|
||||
|
||||
if self.value_norm_type == "exp":
|
||||
self.rms = ExponentialRunningMeanStd(
|
||||
beta=self.value_norm_beta, epsilon=self.value_norm_eps
|
||||
)
|
||||
elif self.value_norm_type == "ma":
|
||||
self.rms = MovingAverageRunningMeanStd()
|
||||
else:
|
||||
raise ValueError(f"Unknown value_norm_type {self.value_norm_type}")
|
||||
self.kl_ctl = None
|
||||
|
||||
self.gconfig = model_api.GenerationHyperparameters(**self.generation_config)
|
||||
if self.generation_size is not None:
|
||||
assert self.generation_size >= self.group_size
|
||||
else:
|
||||
self.generation_size = self.group_size
|
||||
self.gconfig.n = self.generation_size
|
||||
|
||||
def save(self, model: model_api.Model, save_dir: str):
|
||||
if not self.enable_save:
|
||||
return
|
||||
module = model.module
|
||||
if not isinstance(module, ReaLModel):
|
||||
module = module.module
|
||||
module.save_to_hf(
|
||||
tokenizer=model.tokenizer,
|
||||
save_dir=save_dir,
|
||||
)
|
||||
|
||||
def _dispatch_tasks(self, data):
|
||||
math_data, code_data, rlhf_data, ref_data = data, data, data, data
|
||||
return math_data, code_data, rlhf_data, ref_data
|
||||
|
||||
def _gather_tasks(self, data_map):
|
||||
# merge SequenceSamples from math_data, code_data, rlhf_data, ref_data
|
||||
return data_map.get(TASK_TYPE_REF, None)
|
||||
|
||||
@torch.no_grad()
|
||||
def ref_inference(
|
||||
self, model: model_api.Model, input_: SequenceSample, mb_spec: MicroBatchSpec
|
||||
):
|
||||
module = model.module
|
||||
module.eval()
|
||||
|
||||
# This post_hook will gather log probabilities in mini-batches,
|
||||
# reducing peak memory usage.
|
||||
def calc_logprobs(logits, input_):
|
||||
logits /= self.gconfig.temperature
|
||||
|
||||
input_lens = torch.tensor(input_.seqlens["packed_input_ids"]).view(-1)
|
||||
cu_seqlens = torch.nn.functional.pad(input_lens.cumsum(0), (1, 0)).int()
|
||||
|
||||
logprobs = gather_packed_shifted_log_probs(
|
||||
logits, cu_seqlens, input_.data["packed_input_ids"]
|
||||
)
|
||||
return logprobs
|
||||
|
||||
input_flattend = SequenceSample.from_default(
|
||||
ids=list(range(input_.bs * self.group_size)),
|
||||
seqlens=flat2d(input_.seqlens["packed_input_ids"]),
|
||||
data=dict(packed_input_ids=input_.data["packed_input_ids"]),
|
||||
)
|
||||
# add posthook to avoid storing full logits
|
||||
logprobs = module.forward(
|
||||
input_=input_flattend,
|
||||
post_hook=calc_logprobs,
|
||||
output_seqlens=[
|
||||
[x - 1 for x in slens]
|
||||
for slens in input_flattend.seqlens["packed_input_ids"]
|
||||
],
|
||||
mb_spec=mb_spec,
|
||||
)
|
||||
|
||||
res = SequenceSample(
|
||||
keys=["packed_ref_logprobs"],
|
||||
ids=input_.ids,
|
||||
dtypes=dict(packed_ref_logprobs=model.module.dtype),
|
||||
trailing_shapes=dict(packed_ref_logprobs=()),
|
||||
data=dict(packed_ref_logprobs=logprobs),
|
||||
seqlens=dict(
|
||||
packed_ref_logprobs=[
|
||||
[x - 1 for x in slen] for slen in input_.seqlens["packed_input_ids"]
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
return res
|
||||
|
||||
def inference(
|
||||
self,
|
||||
model: model_api.Model,
|
||||
input_: SequenceSample,
|
||||
mb_spec: MicroBatchSpec,
|
||||
) -> SequenceSample:
|
||||
math_data, code_data, rlhf_data, ref_data = self._dispatch_tasks(input_)
|
||||
|
||||
if not hasattr(self, "rew_inf_args") or not isinstance(self.rew_inf_args, dict):
|
||||
raise ValueError("Invalid rew_inf_args. Expected a dictionary.")
|
||||
rewardInterface = PackedRewardInterface(**self.rew_inf_args)
|
||||
logger.info(f"self.rew_inf_args: {self.rew_inf_args}, input_: {input_}")
|
||||
|
||||
task_map = {
|
||||
TASK_TYPE_REF: (self.ref_inference, ref_data),
|
||||
TASK_TYPE_RW_MATH: (rewardInterface.inference, math_data),
|
||||
TASK_TYPE_RW_CODE: (rewardInterface.inference, code_data),
|
||||
}
|
||||
|
||||
def _task_func(func, task_type: str):
|
||||
def _wrapped_func(*args, **kwargs):
|
||||
start_time = time.perf_counter()
|
||||
logger.info(f"[{task_type}] ref_rw task start @ {start_time:.4f}")
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
logger.error(f"{task_type} ref_rw task failed: {e}")
|
||||
finally:
|
||||
duration = time.perf_counter() - start_time
|
||||
logger.info(
|
||||
f"[{task_type}] ref_rw task cost: {duration:.4f}s, start @ {start_time:.4f}"
|
||||
)
|
||||
return result
|
||||
|
||||
return _wrapped_func
|
||||
|
||||
async def _run_tasks() -> dict:
|
||||
tasks = []
|
||||
for task_type, (func, data) in task_map.items():
|
||||
if not data:
|
||||
continue
|
||||
task_func = _task_func(func, task_type)
|
||||
task_args = (model, data, mb_spec)
|
||||
task = asyncio.create_task(asyncio.to_thread(task_func, *task_args))
|
||||
tasks.append((task_type, task))
|
||||
|
||||
results = {}
|
||||
for task_type, task in tasks:
|
||||
try:
|
||||
results[task_type] = await task
|
||||
except Exception as e:
|
||||
logger.error(f"{task_type} task failed: {e}")
|
||||
results[task_type] = None
|
||||
return results
|
||||
|
||||
task_results = asyncio.run(_run_tasks())
|
||||
final_result = self._gather_tasks(task_results)
|
||||
return final_result
|
||||
|
||||
# Mock methods for profiling only.
|
||||
def _mock_inference(
|
||||
self,
|
||||
model: model_api.Model,
|
||||
dataset_input: SequenceSample,
|
||||
) -> SequenceSample:
|
||||
prompt_lens = flat2d(dataset_input.seqlens["packed_prompts"])
|
||||
seqlens = [x + self.gconfig.max_new_tokens for x in prompt_lens]
|
||||
module = model.module
|
||||
if not isinstance(module, ReaLModel):
|
||||
module = module.module
|
||||
mconfig = module.config
|
||||
packed_input_ids = torch.randint(
|
||||
0,
|
||||
mconfig.vocab_size,
|
||||
(sum(seqlens),),
|
||||
dtype=torch.long,
|
||||
device=model.device,
|
||||
)
|
||||
|
||||
return SequenceSample.from_default(
|
||||
seqlens=seqlens,
|
||||
ids=dataset_input.ids,
|
||||
data=dict(packed_input_ids=packed_input_ids),
|
||||
)
|
||||
|
||||
|
||||
model_api.register_interface("ref_rw", RefRwInterface)
|
|
@ -1,317 +0,0 @@
|
|||
import functools
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import time
|
||||
from typing import *
|
||||
|
||||
import numpy as np
|
||||
import pynvml
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
import transformers
|
||||
from torch.cuda import is_initialized
|
||||
|
||||
root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
print(root_dir)
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, root_dir)
|
||||
|
||||
from realhf.api.core import data_api, dfg, model_api
|
||||
from realhf.api.core.config import ModelName
|
||||
from realhf.api.core.model_api import ReaLModelConfig
|
||||
from realhf.base import constants, logging
|
||||
from realhf.base.network import find_free_port
|
||||
from realhf.base.testing import (
|
||||
_DEFAULT_EXPR_NAME,
|
||||
_DEFAULT_TRIAL_NAME,
|
||||
init_global_constants,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("test async ref-rew")
|
||||
os.environ["REAL_MATH_METADATA_PATH"] = "/storage/datasets/id2info.json"
|
||||
|
||||
|
||||
def loadJson():
|
||||
dataDir = os.environ["REAL_MATH_METADATA_PATH"]
|
||||
with open(dataDir, "r") as f:
|
||||
if dataDir.endswith(".jsonl"):
|
||||
samples = [json.loads(line) for line in f.readlines()]
|
||||
else:
|
||||
samples = json.load(f)
|
||||
|
||||
return samples
|
||||
|
||||
|
||||
def _mock_input(batch_size: int, seq_len):
|
||||
vocab_size = 100
|
||||
torch.manual_seed(1)
|
||||
seqs = torch.randint(0, vocab_size, (batch_size, seq_len), dtype=torch.long)
|
||||
|
||||
samples = loadJson()
|
||||
id_list = list(samples.keys())
|
||||
# id_tensor = torch.tensor([id_list[i] for i in range(seqs.shape[0])], dtype=torch.long)  # encode IDs using hash values
|
||||
|
||||
return data_api.SequenceSample.from_default(
|
||||
seqlens=[seq_len for _ in range(seqs.shape[0])],
|
||||
ids=[id_list[i] for i in range(seqs.shape[0])],
|
||||
data=dict(
|
||||
packed_input_ids=seqs.view(-1),
|
||||
# prompt_mask=torch.zeros_like(seqs.view(-1), dtype=torch.bool),
|
||||
packed_prompts=seqs[:, :seq_len].contiguous().view(-1),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def funcion_call(
|
||||
rpc_name: str,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
model_path: str,
|
||||
model_family_name: str,
|
||||
dp: int,
|
||||
pp: int,
|
||||
tp: int,
|
||||
interface_type: dfg.ModelInterfaceType,
|
||||
interface_impl: dfg.ModelInterfaceAbstraction,
|
||||
batch_size: int,
|
||||
prompt_len: int,
|
||||
input_: data_api.SequenceSample | None,
|
||||
port: int,
|
||||
):
|
||||
|
||||
# assert not torch.cuda.is_initialized()
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
|
||||
torch.cuda.set_device(0)
|
||||
assert world_size == (
|
||||
dp * pp * tp
|
||||
), f"dp={dp}, pp={pp}, tp={tp}, world_size={world_size}"
|
||||
assert batch_size % dp == 0, (batch_size, dp)
|
||||
|
||||
# Initialize distributed environment.
|
||||
model_name = ModelName("default", 0)
|
||||
if not dist.is_initialized():
|
||||
logger.info("Setting up distributed environment...")
|
||||
dist.init_process_group(
|
||||
"nccl",
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
init_method=f"tcp://localhost:{port}",
|
||||
)
|
||||
logger.info("Initialized distributed environment.")
|
||||
init_global_constants(
|
||||
num_dp=dp,
|
||||
num_mp=tp,
|
||||
num_pp=pp,
|
||||
sequence_parallel=interface_type == dfg.ModelInterfaceType.TRAIN_STEP,
|
||||
model_name=model_name,
|
||||
max_prompt_len=prompt_len,
|
||||
)
|
||||
torch.cuda.set_device(0)
|
||||
|
||||
# NOTE: import here to avoid CUDA re-initialization
|
||||
|
||||
from realhf.impl.model.nn.real_llm_api import ReaLModel, add_helper_functions
|
||||
|
||||
# Call a method like `config_from_llama` to get the config.
|
||||
mconfig: ReaLModelConfig = getattr(ReaLModel, f"config_from_{model_family_name}")(
|
||||
transformers.AutoConfig.from_pretrained(model_path)
|
||||
)
|
||||
is_critic = rpc_name in ["critic_inf", "critic_train", "rew_inf"]
|
||||
mconfig.is_critic = is_critic
|
||||
with constants.model_scope(model_name):
|
||||
# Construct the model.
|
||||
logger.info(f"Loading model from {model_path}...")
|
||||
module = ReaLModel(mconfig, dtype=torch.bfloat16, device="cuda")
|
||||
setattr(ReaLModel, "save_to_hf", getattr(ReaLModel, f"to_{model_family_name}"))
|
||||
setattr(
|
||||
ReaLModel, "load_from_hf", getattr(ReaLModel, f"from_{model_family_name}")
|
||||
)
|
||||
module._instantiation_hooks.append(
|
||||
lambda: getattr(module, f"from_{model_family_name}")(
|
||||
load_dir=model_path,
|
||||
init_critic_from_actor=is_critic,
|
||||
)
|
||||
)
|
||||
add_helper_functions(module)
|
||||
module.instantiate()
|
||||
module.eval()
|
||||
|
||||
tokenizer = data_api.load_hf_tokenizer(model_path)
|
||||
|
||||
model = model_api.Model(
|
||||
name=model_name,
|
||||
module=module,
|
||||
tokenizer=tokenizer,
|
||||
device=module.device,
|
||||
dtype=module.dtype,
|
||||
)
|
||||
if interface_type == dfg.ModelInterfaceType.TRAIN_STEP:
|
||||
from realhf.impl.model.backend.megatron import MegatronTrainBackend
|
||||
|
||||
backend = MegatronTrainBackend()
|
||||
else:
|
||||
from realhf.impl.model.backend.inference import PipelineInferenceBackend
|
||||
|
||||
backend = PipelineInferenceBackend()
|
||||
|
||||
logger.info("Running backend initialization...")
|
||||
ft_spec = model_api.FinetuneSpec(
|
||||
total_train_epochs=1,
|
||||
dataset_size=128,
|
||||
train_batch_size=128,
|
||||
)
|
||||
model = backend.initialize(model, ft_spec)
|
||||
|
||||
interface = model_api.make_interface(interface_impl)
|
||||
|
||||
if input_ is None:
|
||||
input_ = _mock_input(batch_size, prompt_len)
|
||||
|
||||
input_ = input_.cuda()
|
||||
|
||||
mb_spec = model_api.MicroBatchSpec()
|
||||
|
||||
logger.info("Running interface computation...")
|
||||
start = time.perf_counter_ns()
|
||||
if interface_type == dfg.ModelInterfaceType.GENERATE:
|
||||
res = interface.generate(model, input_, mb_spec)
|
||||
elif interface_type == dfg.ModelInterfaceType.TRAIN_STEP:
|
||||
res = interface.train_step(model, input_)
|
||||
else:
|
||||
res = interface.inference(model, input_, mb_spec)
|
||||
|
||||
if constants.model_parallel_rank() == 0 and constants.is_last_pipe_stage():
|
||||
if isinstance(res, data_api.SequenceSample):
|
||||
res = res.cpu()
|
||||
|
||||
comsumed = time.perf_counter_ns() - start
|
||||
logger.info(f"{rpc_name} Computation done. {comsumed} ns")
|
||||
return res
|
||||
|
||||
|
||||
def run_function_call(
|
||||
rpc_name: str,
|
||||
model_path: str,
|
||||
model_family_name: str,
|
||||
batch_size: int,
|
||||
prompt_len: int,
|
||||
gen_len: int,
|
||||
input_: data_api.SequenceSample | None,
|
||||
) -> data_api.SequenceSample | None:
|
||||
assert rpc_name in [
|
||||
"actor_gen",
|
||||
"actor_train",
|
||||
"critic_inf",
|
||||
"rew_inf",
|
||||
"critic_train",
|
||||
"ref_inf",
|
||||
"ref_rw",
|
||||
]
|
||||
|
||||
ref_rw_interface = dfg.ModelInterfaceAbstraction(
|
||||
"ref_rw",
|
||||
args=dict(
|
||||
generation_config=dict(
|
||||
max_new_tokens=gen_len, min_new_tokens=gen_len, greedy=True
|
||||
),
|
||||
rew_inf_args=dict(
|
||||
tokenizer_path=model_path,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
ppo_actor_interface = dfg.ModelInterfaceAbstraction(
|
||||
"ppo_actor",
|
||||
args=dict(
|
||||
generation_config=dict(
|
||||
max_new_tokens=gen_len, min_new_tokens=gen_len, greedy=True
|
||||
),
|
||||
rew_inf_args=dict(
|
||||
tokenizer_path=model_path,
|
||||
),
|
||||
),
|
||||
)
|
||||
ppo_critic_interface = dfg.ModelInterfaceAbstraction("ppo_critic")
|
||||
rw_interface = dfg.ModelInterfaceAbstraction(
|
||||
"paired_rw",
|
||||
)
|
||||
if rpc_name == "actor_gen":
|
||||
interface_type = dfg.ModelInterfaceType.GENERATE
|
||||
interface_impl = ppo_actor_interface
|
||||
elif rpc_name == "actor_train":
|
||||
interface_type = dfg.ModelInterfaceType.TRAIN_STEP
|
||||
interface_impl = ppo_actor_interface
|
||||
elif rpc_name == "critic_inf":
|
||||
interface_type = dfg.ModelInterfaceType.INFERENCE
|
||||
interface_impl = ppo_critic_interface
|
||||
elif rpc_name == "ref_inf":
|
||||
interface_type = dfg.ModelInterfaceType.INFERENCE
|
||||
interface_impl = ppo_actor_interface
|
||||
elif rpc_name == "ref_rw":
|
||||
interface_type = dfg.ModelInterfaceType.INFERENCE
|
||||
interface_impl = ref_rw_interface
|
||||
elif rpc_name == "critic_train":
|
||||
interface_type = dfg.ModelInterfaceType.TRAIN_STEP
|
||||
interface_impl = ppo_critic_interface
|
||||
else:
|
||||
interface_type = dfg.ModelInterfaceType.INFERENCE
|
||||
interface_impl = rw_interface
|
||||
|
||||
logger.info(f"Running RPC {rpc_name}...")
|
||||
|
||||
port = find_free_port()
|
||||
res = funcion_call(
|
||||
rank=0,
|
||||
rpc_name=rpc_name,
|
||||
world_size=1,
|
||||
model_path=model_path,
|
||||
model_family_name=model_family_name,
|
||||
dp=1,
|
||||
pp=1,
|
||||
tp=1,
|
||||
interface_type=interface_type,
|
||||
interface_impl=interface_impl,
|
||||
batch_size=batch_size,
|
||||
prompt_len=prompt_len,
|
||||
input_=input_,
|
||||
port=port,
|
||||
)
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
if isinstance(res, data_api.SequenceSample):
|
||||
return res
|
||||
else:
|
||||
logger.info(f"RPC {rpc_name} stats: {res}")
|
||||
|
||||
|
||||
def main():
|
||||
mp.set_start_method("spawn", force=True)
|
||||
|
||||
model_family_name = "qwen2"
|
||||
batch_size = 16
|
||||
prompt_len = 128
|
||||
gen_len = 4096
|
||||
model_path = "/storage/models/DeepSeek-R1-Distill-Qwen-1.5B"
|
||||
|
||||
constants.set_experiment_trial_names(_DEFAULT_EXPR_NAME, _DEFAULT_TRIAL_NAME)
|
||||
|
||||
for i in range(2):
|
||||
ref_rw_res = run_function_call(
|
||||
"ref_rw",
|
||||
model_family_name=model_family_name,
|
||||
model_path=model_path,
|
||||
batch_size=batch_size,
|
||||
prompt_len=prompt_len,
|
||||
gen_len=gen_len,
|
||||
input_=None,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|