dataset loading fixed

This commit is contained in:
bowei.fw 2025-03-20 21:52:01 +08:00
parent f8586e47c8
commit eb1e8a7592
6 changed files with 356 additions and 1036 deletions

View File

@ -8,6 +8,7 @@ import os
import random
import time
from contextlib import contextmanager
from enum import Enum
# NOTE: We don't use wildcard imports here because the type
# `Sequence` has a very similar name to `SequenceSample`.
@ -43,6 +44,8 @@ from realhf.base.cluster import spec as cluster_spec
logger = logging.getLogger("api.data")
RL_TASKS = ["math", "code", "rlhf"]
def load_hf_tokenizer(
model_name_or_path: str,
@ -529,6 +532,7 @@ class SequenceSample:
"rewards",
"greedy_rewards",
"base_scores",
"task_ids",
]:
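# NOTE: each of these keys (including the newly added "task_ids") holds exactly
# one value per sequence, hence the per-sequence length entry of [1] below.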
return [[1] for _ in seqlens]
elif key in [

View File

@ -85,7 +85,7 @@ class PromptDataset(torch.utils.data.Dataset):
)
class MATHPromptDataset(torch.utils.data.Dataset):
class MATHCodePromptDataset(torch.utils.data.Dataset):
def __init__(
self,
@ -93,21 +93,9 @@ class MATHPromptDataset(torch.utils.data.Dataset):
max_length: Optional[int] = None,
dataset_path: Optional[str] = None,
dataset_builder: Optional[Callable[[], List[Dict]]] = None,
fill_to_max_length: bool = False,
filter_threshold: float = 1e4,
max_filter_percentage: float = 0.0,
):
"""A dataset with prompts. Usually used for PPO.
Args:
util (api.data.DatasetUtility): Dataset utilities such as the tokenizer, seed, and data-parallel rank.
max_length (Optional[int], optional): The maximum length of each sequence in the batch.
dataset_path (Optional[str], optional): Path to the dataset json/jsonl file.
The json/jsonl file should be a list of dictionaries. Each element in the list should have
a key "prompt". Defaults to None.
dataset_builder (Optional[Callable[[], List[Dict]]], optional): Alternative to dataset_path.
A callable that returns a list of dictionaries. Defaults to None.
"""
self._util = util
self.max_length = max_length
@ -115,6 +103,7 @@ class MATHPromptDataset(torch.utils.data.Dataset):
prompts_str = [x["prompt"] for x in data]
self.ids = [x["query_id"] for x in data]
self.tasks_ids = [data_api.RL_TASKS.index(x["task"]) for x in data]
if "scores" in data[0]:
self.base_scores = [np.mean(x["scores"]) for x in data]
util.tokenizer.padding_side = "left"
@ -127,148 +116,6 @@ class MATHPromptDataset(torch.utils.data.Dataset):
return_attention_mask=False,
)
if fill_to_max_length:
for i in range(len(prompt_encodings["length"])):
x = prompt_encodings["input_ids"][i]
if max_length > len(x):
# Fill with the final non-pad token to the left.
prompt_encodings["input_ids"][i] = [x[-1]] * (
max_length - len(x)
) + x
prompt_encodings["length"][i] = max_length
logger.info(f"{len(data)} samples, checking lengths (max_length={max_length})")
indices = [
i for i, x in enumerate(prompt_encodings["length"]) if x <= max_length
]
logger.info(f"{len(indices)} samples remain")
self.prompt_lengths = [int(prompt_encodings["length"][idx]) for idx in indices]
self.prompts = [prompt_encodings["input_ids"][idx] for idx in indices]
self.ids = [self.ids[idx] + f"@idx:{idx}-{util.dp_rank}" for idx in indices]
if "scores" in data[0]:
self.base_scores = [self.base_scores[idx] for idx in indices]
assert all(len(x) == l for x, l in zip(self.prompts, self.prompt_lengths))
logger.info(f"Number of prompts in the dataset: {len(self.prompts)}")
self.active_indices = list(range(len(self.prompts)))
self.filter_threshold = filter_threshold
self.max_filter_percentage = max_filter_percentage
@property
def util(self):
return self._util
def __len__(self):
return len(self.active_indices)
def __getitem__(self, idx):
# print(self.base_scores)
idx = self.active_indices[idx]
if hasattr(self, "base_scores"):
return data_api.SequenceSample.from_default(
ids=[self.ids[idx]],
seqlens=[self.prompt_lengths[idx]],
data=dict(
packed_prompts=torch.tensor(self.prompts[idx], dtype=torch.long),
base_scores=torch.tensor(
[self.base_scores[idx]], dtype=torch.float32
),
),
)
else:
return data_api.SequenceSample.from_default(
ids=[self.ids[idx]],
seqlens=[self.prompt_lengths[idx]],
data=dict(
packed_prompts=torch.tensor(self.prompts[idx], dtype=torch.long)
),
)
def filter(self, eval_scores: Dict[Hashable, float]):
# Get all data indices that have a higher score than the threshold.
idx2scores_to_remove = {}
for pop_idx, idx in enumerate(self.active_indices):
data_id = self.ids[idx]
if data_id not in eval_scores:
continue
if eval_scores[data_id] > self.filter_threshold:
idx2scores_to_remove[pop_idx] = eval_scores[data_id]
# Control the number of samples to be removed according to max_filter_percentage.
n = int(len(self.active_indices) * self.max_filter_percentage)
indices_to_remove = sorted(
idx2scores_to_remove.keys(),
key=lambda x: idx2scores_to_remove[x],
reverse=True,
)[:n]
for pop_idx in sorted(indices_to_remove, reverse=True):
self.active_indices.pop(pop_idx)
logger.info(
f"Math prompt dataset DP rank {self.util.dp_rank} filtered"
f" {len(indices_to_remove)} samples, {len(self.active_indices)} samples remain. "
f"Original dataset size: {len(self.prompts)}. "
f"Filter threshold: {self.filter_threshold}. "
f"Max filter percentage: {self.max_filter_percentage}. "
f"Current number of eval scores: {len(eval_scores)}."
)
class CODEPromptDataset(torch.utils.data.Dataset):
def __init__(
self,
util: data_api.DatasetUtility,
max_length: Optional[int] = None,
dataset_path: Optional[str] = None,
dataset_builder: Optional[Callable[[], List[Dict]]] = None,
fill_to_max_length: bool = False,
filter_threshold: float = 1e4,
max_filter_percentage: float = 0.0,
):
"""A dataset with prompts. Usually used for PPO.
Args:
util (api.data.DatasetUtility): Dataset utilities such as the tokenizer, seed, and data-parallel rank.
max_length (Optional[int], optional): The maximum length of each sequence in the batch.
dataset_path (Optional[str], optional): Path to the dataset json/jsonl file.
The json/jsonl file should be a list of dictionaries. Each element in the list should have
a key "question". Defaults to None.
dataset_builder (Optional[Callable[[], List[Dict]]], optional): Alternative to dataset_path.
A callable that returns a list of dictionaries. Defaults to None.
"""
self._util = util
self.max_length = max_length
data = data_api.load_shuffle_split_dataset(util, dataset_path, dataset_builder)
prompts_str = [x["question"] for x in data]
self.ids = [x["id"] for x in data]
if "scores" in data[0]:
self.base_scores = [np.mean(x["scores"]) for x in data]
util.tokenizer.padding_side = "left"
prompt_encodings = util.tokenizer(
prompts_str,
truncation=True,
# max_length=max_length,
padding=False,
return_length=True,
return_attention_mask=False,
)
if fill_to_max_length:
for i in range(len(prompt_encodings["length"])):
x = prompt_encodings["input_ids"][i]
if max_length > len(x):
# Fill with the final non-pad token to the left.
prompt_encodings["input_ids"][i] = [x[-1]] * (
max_length - len(x)
) + x
prompt_encodings["length"][i] = max_length
logger.info(f"{len(data)} samples, checking lengths (max_length={max_length})")
indices = [
i for i, x in enumerate(prompt_encodings["length"]) if x <= max_length
@ -300,32 +147,20 @@ class CODEPromptDataset(torch.utils.data.Dataset):
def __getitem__(self, idx):
# print(self.base_scores)
print(
f"CODEPromptDataset idx{idx}, use idx: {self.active_indices[idx]}, size: {len(self.ids)}"
)
idx = self.active_indices[idx]
data = dict(
task_ids=torch.tensor([self.tasks_ids[idx]], dtype=torch.long),
packed_prompts=torch.tensor(self.prompts[idx], dtype=torch.long),
)
if hasattr(self, "base_scores"):
return data_api.SequenceSample.from_default(
ids=[self.ids[idx]],
seqlens=[self.prompt_lengths[idx]],
data=dict(
packed_prompts=torch.tensor(self.prompts[idx], dtype=torch.long),
base_scores=torch.tensor(
[self.base_scores[idx]], dtype=torch.float32
),
),
metadata=dict(random_id=[uuid.uuid4()]),
)
else:
return data_api.SequenceSample.from_default(
ids=[self.ids[idx]],
seqlens=[self.prompt_lengths[idx]],
data=dict(
packed_prompts=torch.tensor(self.prompts[idx], dtype=torch.long)
),
metadata=dict(random_id=[uuid.uuid4()]),
data["base_scores"] = torch.tensor(
[self.base_scores[idx]], dtype=torch.float32
)
return data_api.SequenceSample.from_default(
ids=[self.ids[idx]],
seqlens=[self.prompt_lengths[idx]],
data=data,
)
def filter(self, eval_scores: Dict[Hashable, float]):
# Get all data indices that have a higher score than the threshold.
@ -348,7 +183,7 @@ class CODEPromptDataset(torch.utils.data.Dataset):
for pop_idx in sorted(indices_to_remove, reverse=True):
self.active_indices.pop(pop_idx)
logger.info(
f"Code prompt dataset DP rank {self.util.dp_rank} filtered"
f"Math prompt dataset DP rank {self.util.dp_rank} filtered"
f" {len(indices_to_remove)} samples, {len(self.active_indices)} samples remain. "
f"Original dataset size: {len(self.prompts)}. "
f"Filter threshold: {self.filter_threshold}. "
@ -359,12 +194,11 @@ class CODEPromptDataset(torch.utils.data.Dataset):
if not __name__ == "__main__":
data_api.register_dataset("prompt", PromptDataset)
data_api.register_dataset("math_prompt", MATHPromptDataset)
data_api.register_dataset("code_prompt", CODEPromptDataset)
data_api.register_dataset("math_code_prompt", MATHCodePromptDataset)
else:
from transformers import AutoTokenizer
dataset = MATHPromptDataset(
dataset = MATHCodePromptDataset(
data_api.DatasetUtility(
seed=0,
dp_rank=0,
@ -374,23 +208,17 @@ else:
),
),
max_length=512,
dataset_path="/storage/openpsi/data/math/Qwen_RL_training_xss/0101/train_rl@0101_with_qwqsft-7b_score.jsonl",
dataset_path="/storage/openpsi/users/bowei.fw/data/code_math.jsonl",
)
dataset = CODEPromptDataset(
data_api.DatasetUtility(
seed=0,
dp_rank=0,
world_size=1,
tokenizer=AutoTokenizer.from_pretrained(
"/storage/openpsi/models/Qwen__Qwen2-1.5B-Instruct/"
),
),
max_length=512,
dataset_path="/storage/datasets/codeparrot-apps-test.jsonl",
)
data_generator = enumerate(dataset)
dataset_batch_counter, cur_sample = next(data_generator)
print(
f"size: {len(dataset)}, index: {dataset_batch_counter}, cur_sample: {cur_sample}"
dataloader = torch.utils.data.DataLoader(
dataset,
collate_fn=data_api.SequenceSample.gather,
# NOTE: This is *NOT* the actual batch size for training.
# It is just a convenient size for loading data into workers.
batch_size=4,
shuffle=True,
)
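# SequenceSample.gather, used as collate_fn above, merges the per-prompt
# SequenceSamples of a batch into a single SequenceSample.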
print(f"size: {len(dataset)}")
for d in dataloader:
print(d.ids)

View File

@ -0,0 +1,101 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import dataclasses
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict
import torch
import realhf.api.core.model_api as model_api
import realhf.base.logging as logging
from realhf.api.core.config import ModelInterfaceAbstraction
from realhf.api.core.data_api import RL_TASKS, MicroBatchSpec, SequenceSample
from realhf.base.datapack import flat2d
from realhf.impl.model.nn.real_llm_api import ReaLModel
logger = logging.getLogger(__name__)
@dataclasses.dataclass
class FusedThreadingForwardInterface(model_api.ModelInterface):
def __init__(self, interfaces: Dict[str, ModelInterfaceAbstraction]):
self.interfaces = {
key: model_api.make_interface(interface)
for key, interface in interfaces.items()
}
def run_interface(
self,
interface_name: str,
model,
data,
mb_spec,
) -> SequenceSample | None:
tik = time.perf_counter()
res = self.interfaces[interface_name].inference(model, data, mb_spec)
t = time.perf_counter() - tik
logger.info(f"Interface {interface_name} cost {t} s")
return res
def inference(
self,
model: model_api.Model,
data: SequenceSample,
mb_spec: MicroBatchSpec,
) -> SequenceSample | None:
with ThreadPoolExecutor(max_workers=len(self.interfaces)) as executor:
tasks = []
for interface_name in self.interfaces:
task = executor.submit(
self.run_interface, interface_name, model, data, mb_spec
)
tasks.append(task)
final_result = None
for task in as_completed(tasks):
res = task.result()
if res is None:
continue
if final_result is None:
final_result = res
else:
final_result.update_(res)
return final_result
# Mock methods for profiling only.
def _mock_inference(
self,
model: model_api.Model,
data: SequenceSample,
) -> SequenceSample:
prompt_lens = flat2d(data.seqlens["packed_prompts"])
seqlens = [x + 1024 for x in prompt_lens]
module = model.module
if not isinstance(module, ReaLModel):
module = module.module
mconfig = module.config
packed_input_ids = torch.randint(
0,
mconfig.vocab_size,
(sum(seqlens),),
dtype=torch.long,
device=model.device,
)
n_tasks = len(RL_TASKS)
task_ids = torch.randint(
0, n_tasks, (data.bs,), dtype=torch.long, device=model.device
)
return SequenceSample.from_default(
seqlens=seqlens,
ids=data.ids,
data=dict(packed_input_ids=packed_input_ids, task_ids=task_ids),
)
model_api.register_interface("fused-threading", FusedThreadingForwardInterface)

View File

@ -1,40 +1,41 @@
# Copyright 2025 Ant Group Inc.
import collections
import copy
import ast
import asyncio
import dataclasses
import html
import itertools
import json
import os
import random
import re
import time
import xml.etree.ElementTree as ET
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List, Optional
from typing import Dict, List, Tuple
import colorama
import numpy as np
import requests
import torch
import torch.distributed as dist
import tqdm
import transformers
import realhf.api.core.model_api as model_api
import realhf.base.logging as logging
from functioncall.code.verify import code_verify
from functioncall.math.verify import math_verify
from realhf.api.core.data_api import SequenceSample, load_hf_tokenizer
from realhf.api.core.data_api import (
RL_TASKS,
MicroBatchSpec,
SequenceSample,
load_hf_tokenizer,
)
from realhf.base import constants
from realhf.base.constants import data_parallel_group, data_parallel_world_size
from realhf.base.datapack import flat2d
from realhf.impl.model.interface.math_parser import parse_lines_in_parallel
from realhf.impl.model.nn.real_llm_api import ReaLModel
from realhf.impl.model.interface.math_parser import (
parse_lines_in_parallel as math_verify_local,
)
logger = logging.getLogger("Packed Reward Modeling Interface", "benchmark")
ENABLE_FUNCTION_CALL = True if os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", "") else False
math_verify_call = math_verify if ENABLE_FUNCTION_CALL else parse_lines_in_parallel
math_verify_call = math_verify if ENABLE_FUNCTION_CALL else math_verify_local
class VerifierException(Exception):
@ -53,7 +54,7 @@ def extract_python_code(text, min_length=20, strict_syntax=True):
# verify code syntax
if strict_syntax:
try:
parse(clean_block, mode="exec")
ast.parse(clean_block, mode="exec")
except (SyntaxError, IndentationError):
continue
@ -118,40 +119,28 @@ def check_with_elementtree(text):
return False, f"Error: XML格式错误, {str(e)}"
def reward_caculate(task, _answers, query_id_strs):
def dispatch_reward_calculation(task, answers, query_id_strs) -> List:
assert len(answers) == len(query_id_strs)
format_rewards = []
if task == "math":
format_rewards = math_verify_call(_answers, query_id_strs)
else:
codes = [extract_python_code(_answer) for _answer in _answers]
format_rewards = math_verify_call(answers, query_id_strs)
elif task == "code":
codes = [extract_python_code(_answer) for _answer in answers]
format_rewards = code_verify(codes, query_id_strs)
logger.info(
f"reward_caculate, task: {task}, size: {len(query_id_strs)}, query_id_0: {query_id_strs[0]}"
)
assert len(format_rewards) == len(answers), task
return format_rewards
def retokenize(
def retokenize_and_verify(
task,
tokenizer,
packed_input_ids,
input_cu_seqlens,
prompts,
prompt_cu_seqlens,
query_ids,
prompt_ids: List[List[int]],
seq_ids: List[List[int]],
query_ids: List[str],
check_xml_format=False,
do_eval=False,
):
input_ids = [
packed_input_ids[start:end]
for start, end in zip(input_cu_seqlens[:-1], input_cu_seqlens[1:])
]
prompt_ids = [
prompts[start:end]
for start, end in zip(prompt_cu_seqlens[:-1], prompt_cu_seqlens[1:])
]
seq_strs = tokenizer.batch_decode(
input_ids, clean_up_tokenization_spaces=False, skip_special_tokens=True
seq_ids, clean_up_tokenization_spaces=False, skip_special_tokens=True
)
prompt_strs = tokenizer.batch_decode(
prompt_ids, clean_up_tokenization_spaces=False, skip_special_tokens=True
@ -159,250 +148,131 @@ def retokenize(
# query_id_strs = query_ids
query_id_strs = [query_id.split("@")[0] for query_id in query_ids]
logger.info(
f"retokenize, query_id_strs:{query_id_strs}, seq_strs:{seq_strs}, prompt_strs:{prompt_strs}"
)
answers = [
seq_str.split(prompt_str)[1]
for seq_str, prompt_str in zip(seq_strs, prompt_strs)
]
format_rewards = []
queryid_to_results = collections.defaultdict(list)
# 8 processes on each node, with 10 subprocesses each
if do_eval == True:
_answers = [
seq_str.split(prompt_str)[1]
for seq_str, prompt_str in zip(seq_strs, prompt_strs)
]
format_rewards = dispatch_reward_calculation(task, answers, query_id_strs)
format_rewards = math_verify_call(_answers, query_id_strs)
if check_xml_format:
with ThreadPoolExecutor(max_workers=22) as executor:
futures = [
executor.submit(check_with_elementtree, answer_str)
for answer_str in _answers
]
# xml_rewards = []
for idx, future in enumerate(futures):
xml_reward, _ = future.result()
# xml_rewards.append(xml_reward)
if xml_reward == 1 and format_rewards[idx] == 0:
format_rewards[idx] = -0.8
elif xml_reward == 0 and format_rewards[idx] == 0:
format_rewards[idx] = -1
if check_xml_format:
for idx, answer in enumerate(answers):
xml_reward, _ = check_with_elementtree(answer)
if xml_reward == 1 and format_rewards[idx] == 0:
format_rewards[idx] = -0.8
elif xml_reward == 0 and format_rewards[idx] == 0:
format_rewards[idx] = -1
for query_id_str, format_reward in zip(query_id_strs, format_rewards):
if query_id_str not in queryid_to_results:
queryid_to_results[query_id_str] = []
queryid_to_results[query_id_str].append(format_reward)
else:
for query_id_str in query_id_strs:
if query_id_str not in queryid_to_results:
queryid_to_results[query_id_str] = []
queryid_to_results[query_id_str].append(0)
format_rewards.append(0)
return format_rewards, prompt_strs, prompt_ids, seq_strs, queryid_to_results
return format_rewards, prompt_strs, seq_strs
@dataclasses.dataclass
class PackedRewardInterface(model_api.ModelInterface):
enable_save: bool = False
class MultiTaskCPURewardInterface(model_api.ModelInterface):
tokenizer_path: str = "/storage/models/Qwen__Qwen2.5-1.5B"
output_scaling: float = 1.0
rm_output_scaling: float = 1.0
rm_output_bias: float = 0.0
output_scaling: float = 1.0
output_bias: float = 0.0
loss_fun = torch.nn.CrossEntropyLoss(reduction="none")
max_sync_length: int = 2048
rw_type: str = "sparse"
task: str = "math" # math or countdown
check_xml_format: bool = False
post_process: str = "sigmoid"
group_size: int = 1
check_verifier_status: bool = False
def __post_init__(self):
self.tokenizer = load_hf_tokenizer(self.tokenizer_path)
logger.info(f"rm_output_scaling: {self.rm_output_scaling}")
logger.info(f"rm_output_bias: {self.rm_output_bias}")
logger.info(f"output_scaling: {self.output_scaling}")
logger.info(f"output_bias: {self.output_bias}")
logger.info(f"max_sync_length: {self.max_sync_length}")
logger.info(f"rw_type: {self.rw_type}")
logger.info(f"post_process: {self.post_process}")
if constants.parallelism_rank() == 0:
logger.info(f"output_scaling: {self.output_scaling}")
logger.info(f"output_bias: {self.output_bias}")
logger.info(f"rw_type: {self.rw_type}")
if (
constants.data_parallel_world_size()
< constants.parallelism_group_size()
):
logger.warning(
"There's no reason to use tensor and pipeline parallelism for CPU reward."
)
def inference(
def _dispatch_tasks(self, data: SequenceSample) -> Tuple[Dict, Dict]:
xs = data.unpack()
dispatched = {}
dispatched_indices = {}
for task_idx, task_name in enumerate(RL_TASKS):
indices = (data.data["task_ids"] == task_idx).numpy().tolist()
if any(indices):
dispatched[task_name] = SequenceSample.gather([xs[i] for i in indices])
dispatched_indices[task_name] = indices
return dispatched, dispatched_indices
def _gather_tasks(
self, results: Dict, dispatched_indices: Dict, bs: int
) -> SequenceSample:
xs = [None for _ in range(bs)]
for task_name, indices in dispatched_indices.items():
xxs = results[task_name].unpack()
for i, xx in zip(indices, xxs):
xs[i] = xx
assert all(xs)
return SequenceSample.gather(xs)
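# Taken together: _dispatch_tasks splits a mixed batch into one SequenceSample per
# task in RL_TASKS according to "task_ids", and _gather_tasks later restores the
# per-task reward results to their original positions in the batch.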
def calculate_task_reward(
self,
model: model_api.Model,
data_: SequenceSample,
mb_spec,
) -> SequenceSample:
packed_input_ids: torch.Tensor = data_.data["packed_input_ids"].squeeze()
input_seqlens = torch.tensor(data_.seqlens["packed_input_ids"]).view(-1)
input_cu_seqlens = torch.nn.functional.pad(
input_seqlens.cumsum(0), (1, 0)
).int()
packed_prompts = data_.data["packed_prompts"]
prompts = []
prompt_seqlens = []
data: SequenceSample,
mb_spec: MicroBatchSpec,
task_type: str,
):
# mb_spec is ignored here
packed_input_ids: torch.Tensor = data.data["packed_input_ids"]
input_seqlens = flat2d(data.seqlens["packed_input_ids"])
seq_ids = []
offset = 0
for x in data_.seqlens["packed_prompts"]:
prompts += [packed_prompts[offset : offset + x[0]]] * self.group_size
offset += x[0]
prompt_seqlens.extend(x * self.group_size)
assert offset == sum(x[0] for x in data_.seqlens["packed_prompts"])
# non_packed_prompts = copy.deepcopy(prompts)
prompts = torch.cat(prompts)
prompt_seqlens = torch.tensor(prompt_seqlens).view(-1)
prompt_cu_seqlens = torch.nn.functional.pad(
prompt_seqlens.cumsum(0), (1, 0)
).int()
query_ids = [data_id for data_id in data_.ids for _ in range(self.group_size)]
format_rewards, prompt_strs, prompt_ids, seq_strs, queryid_to_results = (
retokenize(
self.task,
self.tokenizer,
packed_input_ids,
input_cu_seqlens,
prompts,
prompt_cu_seqlens,
query_ids,
check_xml_format=self.check_xml_format,
do_eval=constants.is_last_pipe_stage(),
for slen in input_seqlens:
seq_ids.append(
packed_input_ids[offset : offset + slen].cpu().numpy().tolist()
)
offset += slen
assert offset == packed_input_ids.shape[0], (offset, packed_input_ids.shape)
prompt_input_ids = data.data["packed_prompts"]
prompt_len = flat2d(data.seqlens["packed_prompts"])
prompt_ids = []
offset = 0
for slen in prompt_len:
p = prompt_input_ids[offset : offset + slen].cpu().numpy().tolist()
prompt_ids += [p] * self.group_size
offset += slen
format_rewards, prompt_strs, seq_strs = retokenize_and_verify(
task_type,
self.tokenizer,
prompt_ids=prompt_ids,
seq_ids=seq_ids,
query_ids=[
str(data_id) for data_id in data.ids for _ in range(self.group_size)
],
check_xml_format=self.check_xml_format,
)
assert self.rw_type == "sparse"
dense_scores = torch.zeros_like(packed_input_ids).float()
scores = torch.FloatTensor(format_rewards).to(packed_input_ids.device)
scores[scores == 0] = -1
if len(scores) == 0:
return None
assert dense_scores.shape == packed_input_ids.shape
scores = (
scores.to(packed_input_ids.device) - self.output_bias
) * self.output_scaling
logger.info(f"Math reward logging info @v{model.version.global_step}")
logger.info(
f"before: Format success rate: {torch.FloatTensor(format_rewards).mean().item()}"
)
logger.info(f"number of samples: {len(scores)}, {scores.shape}")
if constants.is_last_pipe_stage():
self.log_rewards_to_file(task_type, model, prompt_strs, seq_strs, scores)
gen_file_path = os.path.join(
constants.LOG_ROOT,
constants.experiment_name(),
constants.trial_name(),
"generated",
f"v{model.version.global_step}r{dist.get_rank()}.txt",
)
os.makedirs(os.path.dirname(gen_file_path), exist_ok=True)
logger.info(
f"Generated samples and rewards will be dumped to: {gen_file_path}"
)
with open(gen_file_path, "w") as _f:
for idx, (score, prompt_str, seq_str) in enumerate(
zip(scores, prompt_strs, seq_strs)
):
info = "\n".join(
[
f"idx: {idx} / {len(scores)}",
f"reward is {score.item()}, prompt is {colorama.Fore.YELLOW + colorama.Style.DIM}{prompt_str}{colorama.Style.RESET_ALL}",
f"sequence is: {colorama.Fore.YELLOW + colorama.Style.DIM}{seq_str.split(prompt_str)[1]}{colorama.Style.RESET_ALL}.",
]
)
_f.write(info + "\n")
gen_file_path = os.path.join(
constants.LOG_ROOT,
constants.experiment_name(),
constants.trial_name(),
"generated_jsonl",
f"v{model.version.global_step}r{dist.get_rank()}.jsonl",
)
os.makedirs(os.path.dirname(gen_file_path), exist_ok=True)
logger.info(
f"Generated samples and rewards will be dumped to: {gen_file_path}"
)
with open(gen_file_path, "w") as _f:
for idx, (score, prompt_str, seq_str) in enumerate(
zip(scores, prompt_strs, seq_strs)
):
_f.write(
json.dumps(
{
"prompt": prompt_str,
"generated": seq_str.split(prompt_str)[1],
"reward": score.item(),
},
ensure_ascii=False,
)
+ "\n"
)
logger.info(
f"Format success rate: {torch.FloatTensor(format_rewards).mean().item()}"
)
pass_at_k = np.mean(
[sum([xx == 1 for xx in x]) > 0 for x in queryid_to_results.values()]
)
avg_num_samples = np.mean([len(x) for x in queryid_to_results.values()])
logger.info(f"pass@k: {pass_at_k}, num_samples: {avg_num_samples}")
logger.info(f"number of samples: {len(scores)}, {scores.shape}")
logger.info(f"reward: {sum(scores) / len(scores)}")
train_pass_monitor_file_path = os.path.join(
constants.LOG_ROOT,
constants.experiment_name(),
constants.trial_name(),
"training_monitor",
f"v{model.version.global_step}r{dist.get_rank()}.jsonl",
)
os.makedirs(os.path.dirname(train_pass_monitor_file_path), exist_ok=True)
logger.info(
f"pass monitor result will be dumped to: {train_pass_monitor_file_path}"
)
with open(train_pass_monitor_file_path, "w") as monitor_file:
for key, value in queryid_to_results.items():
pass1 = sum(value) / len(value)
pass8 = int(sum(value) > 0)
monitor_file.write(
json.dumps(
{"query_id": key, "pass1": pass1, "pass8": pass8},
ensure_ascii=False,
)
+ "\n"
)
model.inc_version()
if scores.dtype != torch.float32:
scores = scores.to(torch.float32)
if dense_scores.dtype != torch.float32:
dense_scores = dense_scores.to(torch.float32)
# NOTE: a placeholder
dense_scores = packed_input_ids.new_zeros(dtype=torch.float32)
res = SequenceSample(
keys=["rewards", "dense_rewards"],
trailing_shapes=dict(rewards=(), dense_rewards=()),
dtypes=dict(rewards=torch.float32, dense_rewards=torch.float32),
ids=data_.ids,
ids=data.ids,
seqlens=dict(
rewards=[
torch.tensor([1 for _ in range(len(x))], dtype=torch.int32)
for x in data_.seqlens["packed_input_ids"]
for x in data.seqlens["packed_input_ids"]
],
dense_rewards=data_.seqlens["packed_input_ids"],
dense_rewards=data.seqlens["packed_input_ids"],
),
data=dict(rewards=scores, dense_rewards=dense_scores),
)
@ -410,13 +280,13 @@ class PackedRewardInterface(model_api.ModelInterface):
# record rewards for each piece of data
avg_scores = []
offset = 0
for i in range(data_.bs):
for i in range(data.bs):
score_lis = scores[
offset : offset + len(data_.seqlens["packed_input_ids"][i])
offset : offset + len(data.seqlens["packed_input_ids"][i])
]
avg_scores.append(score_lis.mean().item())
offset += len(data_.seqlens["packed_input_ids"][i])
assert offset == sum(len(x) for x in data_.seqlens["packed_input_ids"])
offset += len(data.seqlens["packed_input_ids"][i])
assert offset == sum(len(x) for x in data.seqlens["packed_input_ids"])
res.metadata["scores"] = avg_scores
@ -425,20 +295,120 @@ class PackedRewardInterface(model_api.ModelInterface):
np.mean(avg_scores), device=constants.current_device()
)
dist.all_reduce(
avg_score, op=dist.ReduceOp.SUM, group=constants.parallelism_group()
avg_score, op=dist.ReduceOp.SUM, group=constants.data_parallel_group()
)
avg_score /= constants.parallelism_group_size()
avg_score /= constants.data_parallel_group()
avg_score = avg_score.item()
minimal_score = (-1 - self.output_bias) * self.rm_output_scaling
minimal_score = (-1 - self.output_bias) * self.output_scaling
if avg_score <= minimal_score or np.isclose(avg_score, minimal_score):
raise VerifierException(
"All rewards are at minimal value. Probably there are something wrong with the verifier!"
)
if not constants.is_last_pipe_stage():
return None
return res
def log_rewards_to_file(
self, task_type: str, model: model_api.Model, prompt_strs, seq_strs, scores
):
gen_file_path = os.path.join(
constants.LOG_ROOT,
constants.experiment_name(),
constants.trial_name(),
"generated",
task_type,
f"v{model.version.global_step}r{dist.get_rank()}.txt",
)
model_api.register_interface("reward", PackedRewardInterface)
os.makedirs(os.path.dirname(gen_file_path), exist_ok=True)
with open(gen_file_path, "w") as _f:
for idx, (score, prompt_str, seq_str) in enumerate(
zip(scores, prompt_strs, seq_strs)
):
info = "\n".join(
[
f"idx: {idx} / {len(scores)}",
f"reward is {score.item()}, prompt is {colorama.Fore.YELLOW + colorama.Style.DIM}{prompt_str}{colorama.Style.RESET_ALL}",
f"sequence is: {colorama.Fore.YELLOW + colorama.Style.DIM}{seq_str.split(prompt_str)[1]}{colorama.Style.RESET_ALL}.",
]
)
_f.write(info + "\n")
gen_file_path = os.path.join(
constants.LOG_ROOT,
constants.experiment_name(),
constants.trial_name(),
"generated_jsonl",
task_type,
f"v{model.version.global_step}r{dist.get_rank()}.jsonl",
)
os.makedirs(os.path.dirname(gen_file_path), exist_ok=True)
with open(gen_file_path, "w") as _f:
for idx, (score, prompt_str, seq_str) in enumerate(
zip(scores, prompt_strs, seq_strs)
):
_f.write(
json.dumps(
{
"prompt": prompt_str,
"generated": seq_str.split(prompt_str)[1],
"reward": score.item(),
},
ensure_ascii=False,
)
+ "\n"
)
logger.info(f"number of samples: {len(scores)}, {scores.shape}")
logger.info(f"reward: {sum(scores) / len(scores)}")
def inference(
self,
model: model_api.Model,
data: SequenceSample,
mb_spec: MicroBatchSpec,
) -> SequenceSample | None:
task_data, dispatch_indices = self._dispatch_tasks(data)
assert self.rw_type == "sparse"
def _task_func(func, task_type: str):
def _wrapped_func(*args, **kwargs):
start_time = time.perf_counter()
try:
result = func(*args, **kwargs)
except Exception as e:
raise asyncio.CancelledError(f"{task_type} task failed: {e}") from e
finally:
duration = time.perf_counter() - start_time
logger.info(f"[{task_type}] time cost: {duration:.4f}s")
return result
return _wrapped_func
async def _run_tasks():
tasks = []
for task_type, d in task_data.items():
task_func = _task_func(self.calculate_task_reward, task_type)
task_args = (model, d, mb_spec, task_type)
task = asyncio.create_task(asyncio.to_thread(task_func, *task_args))
tasks.append((task_type, task))
task_results = {}
results = await asyncio.gather(*tasks)
for task_type, result in results:
task_results[task_type] = result
return task_results
if constants.is_dp_head():
task_results = asyncio.run(_run_tasks())
final_result = self._gather_tasks(task_results, dispatch_indices, data.bs)
else:
final_result = None
model.inc_version()
return final_result
model_api.register_interface("reward", MultiTaskCPURewardInterface)

View File

@ -1,266 +0,0 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import asyncio
import dataclasses
import functools
import itertools
import time
from typing import Dict, Literal, Optional, Tuple
import torch
import realhf.api.core.model_api as model_api
import realhf.base.logging as logging
import realhf.impl.model.utils.ppo_functional as ppo_functional
from realhf.api.core.data_api import MicroBatchSpec, SequenceSample
from realhf.base.datapack import flat2d
from realhf.impl.model.interface.rw_interface import PackedRewardInterface
from realhf.impl.model.nn.real_llm_api import ReaLModel
from realhf.impl.model.utils.functional import (
gather_packed_shifted_log_probs,
masked_normalization,
)
logger = logging.getLogger("RefRwInterface")
TASK_TYPE_REF: Literal["ref"] = "ref"
TASK_TYPE_RW_MATH: Literal["rw_math"] = "rw_math"
TASK_TYPE_RW_CODE: Literal["rw_code"] = "rw_code"
@dataclasses.dataclass
class RefRwInterface(model_api.ModelInterface):
n_minibatches: int = 4
# Use dict here to allow argument passing through commandline.
generation_config: Dict = dataclasses.field(default_factory=dict)
kl_ctl: float = 0.1
adv_norm: bool = True
discount: float = 1.0
gae_lambda: float = 1.0
eps_clip: float = 0.2
value_eps_clip: float = 0.2
max_reward_clip: float = 5.0
disable_value: bool = False
early_stop_kl: Optional[float] = None # e.g. 0.1
early_stop_imp_ratio: Optional[float] = None # e.g., 10.0
adaptive_kl_ctl: bool = False
adaptive_kl_target: Optional[float] = 6
adaptive_kl_horizon: Optional[float] = 10000
enable_save: bool = True
value_norm: bool = False
value_norm_type: str = dataclasses.field(
metadata={"choices": ["exp", "ma"]}, default="exp"
)
value_norm_beta: float = 0.99995
value_norm_eps: float = 1e-5
group_size: int = 1
generation_size: Optional[int] = None
mask_no_eos_with_zero: bool = False
group_adv_norm: bool = False
mask_too_long: bool = False
use_dense_reward: bool = False
reward_delta: bool = True
token_normalize_scope: Literal["global", "dp"] = "global"
rew_inf_args: Dict = dataclasses.field(default_factory=dict)
def __post_init__(self):
if self.adaptive_kl_ctl:
assert self.adaptive_kl_target is not None
assert self.adaptive_kl_horizon is not None
self.kl_adapter = ppo_functional.AdaptiveKLController(
self.kl_ctl, self.adaptive_kl_target, self.adaptive_kl_horizon
)
else:
self.kl_adapter = ppo_functional.FixedKLController(self.kl_ctl)
if self.value_norm:
from realhf.impl.model.modules import (
ExponentialRunningMeanStd,
MovingAverageRunningMeanStd,
)
if self.value_norm_type == "exp":
self.rms = ExponentialRunningMeanStd(
beta=self.value_norm_beta, epsilon=self.value_norm_eps
)
elif self.value_norm_type == "ma":
self.rms = MovingAverageRunningMeanStd()
else:
raise ValueError(f"Unknown value_norm_type {self.value_norm_type}")
self.kl_ctl = None
self.gconfig = model_api.GenerationHyperparameters(**self.generation_config)
if self.generation_size is not None:
assert self.generation_size >= self.group_size
else:
self.generation_size = self.group_size
self.gconfig.n = self.generation_size
def save(self, model: model_api.Model, save_dir: str):
if not self.enable_save:
return
module = model.module
if not isinstance(module, ReaLModel):
module = module.module
module.save_to_hf(
tokenizer=model.tokenizer,
save_dir=save_dir,
)
def _dispatch_tasks(self, data):
math_data, code_data, rlhf_data, ref_data = data, data, data, data
return math_data, code_data, rlhf_data, ref_data
def _gather_tasks(self, data_map):
# merge SequenceSamples from math_data, code_data, rlhf_data, ref_data
return data_map.get(TASK_TYPE_REF, None)
@torch.no_grad()
def ref_inference(
self, model: model_api.Model, input_: SequenceSample, mb_spec: MicroBatchSpec
):
module = model.module
module.eval()
# This post_hook will gather log probabilities in mini-batches,
# reducing peak memory usage.
def calc_logprobs(logits, input_):
logits /= self.gconfig.temperature
input_lens = torch.tensor(input_.seqlens["packed_input_ids"]).view(-1)
cu_seqlens = torch.nn.functional.pad(input_lens.cumsum(0), (1, 0)).int()
logprobs = gather_packed_shifted_log_probs(
logits, cu_seqlens, input_.data["packed_input_ids"]
)
return logprobs
input_flattend = SequenceSample.from_default(
ids=list(range(input_.bs * self.group_size)),
seqlens=flat2d(input_.seqlens["packed_input_ids"]),
data=dict(packed_input_ids=input_.data["packed_input_ids"]),
)
# add posthook to avoid storing full logits
logprobs = module.forward(
input_=input_flattend,
post_hook=calc_logprobs,
output_seqlens=[
[x - 1 for x in slens]
for slens in input_flattend.seqlens["packed_input_ids"]
],
mb_spec=mb_spec,
)
res = SequenceSample(
keys=["packed_ref_logprobs"],
ids=input_.ids,
dtypes=dict(packed_ref_logprobs=model.module.dtype),
trailing_shapes=dict(packed_ref_logprobs=()),
data=dict(packed_ref_logprobs=logprobs),
seqlens=dict(
packed_ref_logprobs=[
[x - 1 for x in slen] for slen in input_.seqlens["packed_input_ids"]
]
),
)
return res
def inference(
self,
model: model_api.Model,
input_: SequenceSample,
mb_spec: MicroBatchSpec,
) -> SequenceSample:
math_data, code_data, rlhf_data, ref_data = self._dispatch_tasks(input_)
if not hasattr(self, "rew_inf_args") or not isinstance(self.rew_inf_args, dict):
raise ValueError("Invalid rew_inf_args. Expected a dictionary.")
rewardInterface = PackedRewardInterface(**self.rew_inf_args)
logger.info(f"self.rew_inf_args: {self.rew_inf_args}, input_: {input_}")
task_map = {
TASK_TYPE_REF: (self.ref_inference, ref_data),
TASK_TYPE_RW_MATH: (rewardInterface.inference, math_data),
TASK_TYPE_RW_CODE: (rewardInterface.inference, code_data),
}
def _task_func(func, task_type: str):
def _wrapped_func(*args, **kwargs):
start_time = time.perf_counter()
logger.info(f"[{task_type}] ref_rw task start @ {start_time:.4f}")
try:
result = func(*args, **kwargs)
except Exception as e:
logger.error(f"{task_type} ref_rw task failed: {e}")
finally:
duration = time.perf_counter() - start_time
logger.info(
f"[{task_type}] ref_rw task cost: {duration:.4f}s, start @ {start_time:.4f}"
)
return result
return _wrapped_func
async def _run_tasks() -> dict:
tasks = []
for task_type, (func, data) in task_map.items():
if not data:
continue
task_func = _task_func(func, task_type)
task_args = (model, data, mb_spec)
task = asyncio.create_task(asyncio.to_thread(task_func, *task_args))
tasks.append((task_type, task))
results = {}
for task_type, task in tasks:
try:
results[task_type] = await task
except Exception as e:
logger.error(f"{task_type} task failed: {e}")
results[task_type] = None
return results
task_results = asyncio.run(_run_tasks())
final_result = self._gather_tasks(task_results)
return final_result
# Mock methods for profiling only.
def _mock_inference(
self,
model: model_api.Model,
dataset_input: SequenceSample,
) -> SequenceSample:
prompt_lens = flat2d(dataset_input.seqlens["packed_prompts"])
seqlens = [x + self.gconfig.max_new_tokens for x in prompt_lens]
module = model.module
if not isinstance(module, ReaLModel):
module = module.module
mconfig = module.config
packed_input_ids = torch.randint(
0,
mconfig.vocab_size,
(sum(seqlens),),
dtype=torch.long,
device=model.device,
)
return SequenceSample.from_default(
seqlens=seqlens,
ids=dataset_input.ids,
data=dict(packed_input_ids=packed_input_ids),
)
model_api.register_interface("ref_rw", RefRwInterface)

View File

@ -1,317 +0,0 @@
import functools
import gc
import json
import os
import pickle
import time
from typing import *
import numpy as np
import pynvml
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import transformers
from torch.cuda import is_initialized
root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
print(root_dir)
import sys
sys.path.insert(0, root_dir)
from realhf.api.core import data_api, dfg, model_api
from realhf.api.core.config import ModelName
from realhf.api.core.model_api import ReaLModelConfig
from realhf.base import constants, logging
from realhf.base.network import find_free_port
from realhf.base.testing import (
_DEFAULT_EXPR_NAME,
_DEFAULT_TRIAL_NAME,
init_global_constants,
)
logger = logging.getLogger("test async ref-rew")
os.environ["REAL_MATH_METADATA_PATH"] = "/storage/datasets/id2info.json"
def loadJson():
dataDir = os.environ["REAL_MATH_METADATA_PATH"]
with open(dataDir, "r") as f:
if dataDir.endswith(".jsonl"):
samples = [json.loads(line) for line in f.readlines()]
else:
samples = json.load(f)
return samples
def _mock_input(batch_size: int, seq_len):
vocab_size = 100
torch.manual_seed(1)
seqs = torch.randint(0, vocab_size, (batch_size, seq_len), dtype=torch.long)
samples = loadJson()
id_list = list(samples.keys())
# id_tensor = torch.tensor([id_list[i] for i in range(seqs.shape[0])], dtype=torch.long)  # encode ids via their hash values
return data_api.SequenceSample.from_default(
seqlens=[seq_len for _ in range(seqs.shape[0])],
ids=[id_list[i] for i in range(seqs.shape[0])],
data=dict(
packed_input_ids=seqs.view(-1),
# prompt_mask=torch.zeros_like(seqs.view(-1), dtype=torch.bool),
packed_prompts=seqs[:, :seq_len].contiguous().view(-1),
),
)
def funcion_call(
rpc_name: str,
rank: int,
world_size: int,
model_path: str,
model_family_name: str,
dp: int,
pp: int,
tp: int,
interface_type: dfg.ModelInterfaceType,
interface_impl: dfg.ModelInterfaceAbstraction,
batch_size: int,
prompt_len: int,
input_: data_api.SequenceSample | None,
port: int,
):
# assert not torch.cuda.is_initialized()
os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
torch.cuda.set_device(0)
assert world_size == (
dp * pp * tp
), f"dp={dp}, pp={pp}, tp={tp}, world_size={world_size}"
assert batch_size % dp == 0, (batch_size, dp)
# Initialize distributed environment.
model_name = ModelName("default", 0)
if not dist.is_initialized():
logger.info("Setting up distributed environment...")
dist.init_process_group(
"nccl",
rank=rank,
world_size=world_size,
init_method=f"tcp://localhost:{port}",
)
logger.info("Initialized distributed environment.")
init_global_constants(
num_dp=dp,
num_mp=tp,
num_pp=pp,
sequence_parallel=interface_type == dfg.ModelInterfaceType.TRAIN_STEP,
model_name=model_name,
max_prompt_len=prompt_len,
)
torch.cuda.set_device(0)
# NOTE: import here to avoid CUDA re-initialization
from realhf.impl.model.nn.real_llm_api import ReaLModel, add_helper_functions
# Call a method like `config_from_llama` to get the config.
mconfig: ReaLModelConfig = getattr(ReaLModel, f"config_from_{model_family_name}")(
transformers.AutoConfig.from_pretrained(model_path)
)
is_critic = rpc_name in ["critic_inf", "critic_train", "rew_inf"]
mconfig.is_critic = is_critic
with constants.model_scope(model_name):
# Construct the model.
logger.info(f"Loading model from {model_path}...")
module = ReaLModel(mconfig, dtype=torch.bfloat16, device="cuda")
setattr(ReaLModel, "save_to_hf", getattr(ReaLModel, f"to_{model_family_name}"))
setattr(
ReaLModel, "load_from_hf", getattr(ReaLModel, f"from_{model_family_name}")
)
module._instantiation_hooks.append(
lambda: getattr(module, f"from_{model_family_name}")(
load_dir=model_path,
init_critic_from_actor=is_critic,
)
)
add_helper_functions(module)
module.instantiate()
module.eval()
tokenizer = data_api.load_hf_tokenizer(model_path)
model = model_api.Model(
name=model_name,
module=module,
tokenizer=tokenizer,
device=module.device,
dtype=module.dtype,
)
if interface_type == dfg.ModelInterfaceType.TRAIN_STEP:
from realhf.impl.model.backend.megatron import MegatronTrainBackend
backend = MegatronTrainBackend()
else:
from realhf.impl.model.backend.inference import PipelineInferenceBackend
backend = PipelineInferenceBackend()
logger.info("Running backend initialization...")
ft_spec = model_api.FinetuneSpec(
total_train_epochs=1,
dataset_size=128,
train_batch_size=128,
)
model = backend.initialize(model, ft_spec)
interface = model_api.make_interface(interface_impl)
if input_ is None:
input_ = _mock_input(batch_size, prompt_len)
input_ = input_.cuda()
mb_spec = model_api.MicroBatchSpec()
logger.info("Running interface computation...")
start = time.perf_counter_ns()
if interface_type == dfg.ModelInterfaceType.GENERATE:
res = interface.generate(model, input_, mb_spec)
elif interface_type == dfg.ModelInterfaceType.TRAIN_STEP:
res = interface.train_step(model, input_)
else:
res = interface.inference(model, input_, mb_spec)
if constants.model_parallel_rank() == 0 and constants.is_last_pipe_stage():
if isinstance(res, data_api.SequenceSample):
res = res.cpu()
comsumed = time.perf_counter_ns() - start
logger.info(f"{rpc_name} Computation done. {comsumed} ns")
return res
def run_function_call(
rpc_name: str,
model_path: str,
model_family_name: str,
batch_size: int,
prompt_len: int,
gen_len: int,
input_: data_api.SequenceSample | None,
) -> data_api.SequenceSample | None:
assert rpc_name in [
"actor_gen",
"actor_train",
"critic_inf",
"rew_inf",
"critic_train",
"ref_inf",
"ref_rw",
]
ref_rw_interface = dfg.ModelInterfaceAbstraction(
"ref_rw",
args=dict(
generation_config=dict(
max_new_tokens=gen_len, min_new_tokens=gen_len, greedy=True
),
rew_inf_args=dict(
tokenizer_path=model_path,
),
),
)
ppo_actor_interface = dfg.ModelInterfaceAbstraction(
"ppo_actor",
args=dict(
generation_config=dict(
max_new_tokens=gen_len, min_new_tokens=gen_len, greedy=True
),
rew_inf_args=dict(
tokenizer_path=model_path,
),
),
)
ppo_critic_interface = dfg.ModelInterfaceAbstraction("ppo_critic")
rw_interface = dfg.ModelInterfaceAbstraction(
"paired_rw",
)
if rpc_name == "actor_gen":
interface_type = dfg.ModelInterfaceType.GENERATE
interface_impl = ppo_actor_interface
elif rpc_name == "actor_train":
interface_type = dfg.ModelInterfaceType.TRAIN_STEP
interface_impl = ppo_actor_interface
elif rpc_name == "critic_inf":
interface_type = dfg.ModelInterfaceType.INFERENCE
interface_impl = ppo_critic_interface
elif rpc_name == "ref_inf":
interface_type = dfg.ModelInterfaceType.INFERENCE
interface_impl = ppo_actor_interface
elif rpc_name == "ref_rw":
interface_type = dfg.ModelInterfaceType.INFERENCE
interface_impl = ref_rw_interface
elif rpc_name == "critic_train":
interface_type = dfg.ModelInterfaceType.TRAIN_STEP
interface_impl = ppo_critic_interface
else:
interface_type = dfg.ModelInterfaceType.INFERENCE
interface_impl = rw_interface
logger.info(f"Running RPC {rpc_name}...")
port = find_free_port()
res = funcion_call(
rank=0,
rpc_name=rpc_name,
world_size=1,
model_path=model_path,
model_family_name=model_family_name,
dp=1,
pp=1,
tp=1,
interface_type=interface_type,
interface_impl=interface_impl,
batch_size=batch_size,
prompt_len=prompt_len,
input_=input_,
port=port,
)
gc.collect()
torch.cuda.empty_cache()
gc.collect()
if isinstance(res, data_api.SequenceSample):
return res
else:
logger.info(f"RPC {rpc_name} stats: {res}")
def main():
mp.set_start_method("spawn", force=True)
model_family_name = "qwen2"
batch_size = 16
prompt_len = 128
gen_len = 4096
model_path = "/storage/models/DeepSeek-R1-Distill-Qwen-1.5B"
constants.set_experiment_trial_names(_DEFAULT_EXPR_NAME, _DEFAULT_TRIAL_NAME)
for i in range(2):
ref_rw_res = run_function_call(
"ref_rw",
model_family_name=model_family_name,
model_path=model_path,
batch_size=batch_size,
prompt_len=prompt_len,
gen_len=gen_len,
input_=None,
)
if __name__ == "__main__":
main()