Merge updates from ant repository. (#34)

* Cherry-pick commit 90dfd575 "PullRequest: 84 [ADD..." to current branch

* Cherry-pick commit 15e787b7 "PullRequest: 44 eval..." to current branch

* Cherry-pick commit f255ef60 "PullRequest: 85 add ..." to current branch

* Cherry-pick commit c2b4006a "PullRequest: 86 Supp..." to current branch

* Cherry-pick commit fa6c0f3d "PullRequest: 87 upda..." to current branch

* Cherry-pick commit a9ff4af0 "PullRequest: 88 Bump..." to current branch

* Cherry-pick commit 763839aa "PullRequest: 89 Add ..." to current branch

* Cherry-pick commit 21e8064a "PullRequest: 90 Merg..." to current branch

* Cherry-pick commit 94e97670 "PullRequest: 92 Supp..." to current branch

* Cherry-pick commit 92710522 "PullRequest: 91 Supp..." to current branch

* Cherry-pick commit 95aa3f28 "PullRequest: 93 Supp..." to current branch

* Cherry-pick commit 62191f8f "PullRequest: 94 Add ..." to current branch

* Cherry-pick commit baa0249a "PullRequest: 95 Form..." to current branch

* Cherry-pick commit e32945f2 "PullRequest: 96 Chan..." to current branch

* Cherry-pick commit b59286e3 "PullRequest: 98 fix ..." to current branch

* Cherry-pick commit ca2ba43e "PullRequest: 97 Move..." to current branch

* Cherry-pick commit f941700b "PullRequest: 99 Refa..." to current branch

* Cherry-pick commit 95439e70 "PullRequest: 100 Add..." to current branch

* Cherry-pick commit f3ebd941 "PullRequest: 101 Add..." to current branch

* Cherry-pick commit ee4779ea "PullRequest: 103 [Fe..." to current branch

* Cherry-pick commit ce5e24ec "PullRequest: 104 [Fi..." to current branch

* Cherry-pick commit b385761f "PullRequest: 105 [Bu..." to current branch

* Cherry-pick commit 4c21fbb5 "PullRequest: 106 [Bu..." to current branch

* Cherry-pick commit 7f3f14e0 "PullRequest: 108 [Fi..." to current branch

* Cherry-pick commit 8de62701 "PullRequest: 107 [Fe..." to current branch

* Cherry-pick commit ea864b21 "PullRequest: 24 [Fea..." to current branch

* Cherry-pick commit 4a658db3 "PullRequest: 109 [Bu..." to current branch

* Cherry-pick commit aaa12bf1 "PullRequest: 110 [Bu..." to current branch

* Cherry-pick commit 6adb6d9f "PullRequest: 112 [Fi..." to current branch

* Cherry-pick commit 55556bc5 "PullRequest: 111 [Fe..." to current branch

* Cherry-pick commit bfe5ec94 "PullRequest: 114 pri..." to current branch

* Cherry-pick commit 44529c9b "PullRequest: 113 spl..." to current branch

* Cherry-pick commit b1cc73df "PullRequest: 116 [FI..." to current branch

* Cherry-pick commit eff598ce "PullRequest: 115 [Fi..." to current branch

* Cherry-pick commit f7149475 "PullRequest: 119 [Fi..." to current branch

* Cherry-pick commit f1017bfe "PullRequest: 121 add..." to current branch

* Cherry-pick commit 56f6de8d "PullRequest: 120 set..." to current branch

---------

Co-authored-by: 冰临 <shenxujie.sxj@antgroup.com>
Co-authored-by: 温差 <xushusheng.xss@antgroup.com>
Co-authored-by: 郭唯 <kira.gw@antgroup.com>
Co-authored-by: 博惟 <bowei.fw@antgroup.com>
Co-authored-by: 君末 <meijun.mei@antgroup.com>
nuzant 2025-04-27 11:09:25 +08:00 committed by GitHub
parent 42afcdb53a
commit ffc52a1520
94 changed files with 7132 additions and 1609 deletions

View File

@ -21,8 +21,6 @@ ENV NVTE_WITH_USERBUFFERS=1 NVTE_FRAMEWORK=pytorch MAX_JOBS=8 MPI_HOME=/usr/loca
ENV PATH="${PATH}:/opt/hpcx/ompi/bin:/opt/hpcx/ucx/bin"
ENV LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/opt/hpcx/ompi/lib:/opt/hpcx/ucx/lib/"
RUN pip3 install deepspeed==0.14.0 megatron==0.6.0
COPY ./requirements.txt /requirements.txt
RUN pip3 install -r /requirements.txt && rm /requirements.txt

View File

@ -0,0 +1,47 @@
import argparse
import hashlib
import json
from glob import glob
import numpy as np
from tqdm import tqdm
def get_hash(text):
text_bytes = text.encode("utf-8")  # convert the text to a byte string
md5_hash = hashlib.md5(text_bytes).hexdigest()
return md5_hash
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--log_path", type=str)
args = parser.parse_args()
return args
if __name__ == "__main__":
args = get_args()
all_results = {}
for fname in glob(f"{args.log_path}/*.jsonl"):
with open(fname) as f:
cur_results = [json.loads(x) for x in f]
for x in cur_results:
# query_id = x['query_id'].split("@")[0]
if "query_id" in x:
query_id = x["query_id"].split("@")[0]
else:
query_id = get_hash(x["prompt"])
if query_id not in all_results:
all_results[query_id] = []
all_results[query_id].append(x["reward"] > 0)
all_acc = []
for query_id, results in sorted(all_results.items(), key=lambda x: x[0]):
print(query_id, len(results), np.mean(results))
all_acc.append(sum(results) / len(results))
print(len(all_acc))
print(f"Mean accuracy: {np.mean(all_acc)}")

View File

@ -3,6 +3,7 @@ trial_name: 512x16
mode: ray
wandb:
mode: disabled
metric_discovery_port: 17997
recover_mode: auto
recover_retries: 10
allocation_mode: 'sglang.d64p1m1+d32p2m1'
@ -27,6 +28,7 @@ actor:
sglang:
mem_fraction_static: 0.8
triton_attention_num_kv_splits: 16
enable_metrics: True
critic:
type:
_class: qwen2

View File

@ -3,6 +3,7 @@ trial_name: 512x16
mode: ray
wandb:
mode: disabled
metric_discovery_port: 17997
recover_mode: auto
recover_retries: 10
allocation_mode: 'sglang.d16p1m1+d8p2m1'
@ -27,6 +28,7 @@ actor:
sglang:
mem_fraction_static: 0.8
triton_attention_num_kv_splits: 16
enable_metrics: True
critic:
type:
_class: qwen2

View File

@ -3,6 +3,7 @@ trial_name: 512x16
mode: ray
wandb:
mode: disabled
metric_discovery_port: 17997
recover_mode: auto
recover_retries: 10
allocation_mode: 'sglang.d4p1m1+d2p2m1'
@ -27,6 +28,7 @@ actor:
sglang:
mem_fraction_static: 0.8
triton_attention_num_kv_splits: 16
enable_metrics: True
critic:
type:
_class: qwen2

View File

@ -3,6 +3,7 @@ trial_name: 512x32
mode: ray
wandb:
mode: disabled
metric_discovery_port: 17997
recover_mode: auto
recover_retries: 10
allocation_mode: 'sglang.d8m8p1+d4p4m4'
@ -30,6 +31,7 @@ actor:
triton_attention_num_kv_splits: 16
max_running_requests: 128
context_length: 29696
enable_metrics: True
critic:
type:
_class: qwen2

View File

@ -3,6 +3,7 @@ trial_name: 512x16
mode: ray
wandb:
mode: disabled
metric_discovery_port: 17997
recover_mode: auto
recover_retries: 10
allocation_mode: 'sglang.d64p1m1+d32p2m1'
@ -27,6 +28,7 @@ actor:
sglang:
mem_fraction_static: 0.8
triton_attention_num_kv_splits: 16
enable_metrics: True
critic:
type:
_class: qwen2

View File

@ -3,6 +3,7 @@ trial_name: 512x16
mode: ray
wandb:
mode: disabled
metric_discovery_port: 17997
recover_mode: auto
recover_retries: 10
allocation_mode: 'sglang.d16p1m1+d8p2m1'
@ -27,6 +28,7 @@ actor:
sglang:
mem_fraction_static: 0.8
triton_attention_num_kv_splits: 16
enable_metrics: True
critic:
type:
_class: qwen2

View File

@ -3,6 +3,7 @@ trial_name: 512x64
mode: ray
wandb:
mode: disabled
metric_discovery_port: 17997
recover_mode: auto
recover_retries: 10
allocation_mode: 'sglang.d64p1m1+d32p2m1'
@ -30,6 +31,7 @@ actor:
triton_attention_num_kv_splits: 16
max_running_requests: 128
context_length: 18432
enable_metrics: True
critic:
type:
_class: qwen2

View File

@ -0,0 +1,135 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
"""
Dataset Toolkit - Process and validate code/math datasets with flexible input support
"""
import argparse
import json
import logging
import random
from collections import defaultdict
from typing import Dict, List, Optional, Tuple
# Configure console logging
logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.DEBUG)
logger = logging.getLogger(__name__)
def load_jsonl(file_path: str) -> List[Dict]:
"""Load JSONL file with validation"""
try:
with open(file_path, "r", encoding="utf-8") as f:
return [json.loads(line) for line in f]
except FileNotFoundError:
print(f"ERROR: JSONL file not found: {file_path}")
raise
except json.JSONDecodeError as e:
print(f"ERROR: JSON parsing failed in {file_path}: {str(e)}")
raise
def save_file(output_path: str, processed_data: list):
with open(output_path, "w") as f:
for item in processed_data:
f.write(json.dumps(item) + "\n")
def process_math_data(file_path: str) -> List[Dict]:
"""Process math dataset from JSON/JSONL file"""
if not file_path:
return []
raw_data = load_jsonl(file_path)
processed = []
for index, item in enumerate(raw_data):
processed.append(
{
"task": "math",
"query_id": str(item.get("query_id", f"math-{index}")),
"prompt": item["context"],
"solutions": [item["groundtruth"]],
}
)
return processed
def process_code_data(file_path: str) -> List[Dict]:
"""Process code dataset from JSONL file"""
if not file_path:
return []
raw_data = load_jsonl(file_path)
processed = []
for item in raw_data:
# Field extraction and transformation
input_output = item["code_test_cases"]
time_limit = item["meta"]["time_limit"]
seconds = time_limit.get("seconds", 0) + time_limit.get("nanos", 0) / 1e9
memory = item["meta"]["memory_limit_bytes"] / (1024 * 1024)
processed.append(
{
"task": "code",
"query_id": str(item["id"]),
"prompt": item["context"],
"input_output": json.dumps(
{
"inputs": [io.get("input") for io in input_output],
"outputs": [io.get("output") for io in input_output],
"fn_name": item.get("metadata", {}).get("fn_name", ""),
"remote": False,
}
),
"solutions": [item["groundtruth"]],
"language": "PYTHON",
"timeout": seconds,
"memory": memory,
}
)
return processed
def main():
parser = argparse.ArgumentParser(
description="Dataset Toolkit: Process and validate STEM datasets",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--code", help="Path to code dataset (JSONL)")
parser.add_argument("--math", help="Path to math dataset (JSONL)")
parser.add_argument("--output", help="Output file path (JSONL)")
args = parser.parse_args()
if not args.output:
logger.error("Output file required in process mode")
return
processed_data = []
stats = defaultdict(int)
if args.code:
code_data = process_code_data(args.code)
logger.info(f"Loaded {len(code_data)} code items")
processed_data.extend(code_data)
stats["code"] = len(code_data)
if args.math:
math_data = process_math_data(args.math)
logger.info(f"Loaded {len(math_data)} math items")
processed_data.extend(math_data)
stats["math"] = len(math_data)
random.shuffle(processed_data)
save_file(args.output, processed_data)
logger.info("\nProcessing Complete:")
logger.info(f"Total items: {len(processed_data)}")
logger.info(f"Code items: {stats['code']}")
logger.info(f"Math items: {stats['math']}")
if __name__ == "__main__":
main()
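For reference, process_math_data above consumes records with context and groundtruth fields (query_id is optional); a minimal sketch of one such input record, with made-up values, might be:

import json

# Illustrative only (names/values made up): the field shape process_math_data
# reads -- "context" becomes the prompt, "groundtruth" the reference solution,
# and a missing "query_id" falls back to "math-<index>".
math_item = {"query_id": "math-demo-0", "context": "Compute 2+2.", "groundtruth": "4"}
print(json.dumps(math_item))

Piping such lines into a file and passing it via --math together with --output would exercise the math branch; the code branch additionally expects code_test_cases and the meta time/memory limits read above.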

View File

@ -46,16 +46,24 @@ def process_code_data(file_path: str) -> List[Dict]:
"task": "code",
"query_id": str(item["id"]),
"prompt": item["question"],
"solutions": item.get("solutions", []),
"input_output": json.dumps(
{
"inputs": input_output.get("inputs", []),
"outputs": input_output.get("outputs", []),
"fn_name": item.get("metadata", {}).get("fn_name", ""),
"remote": False,
}
),
"language": item.get("language", "PYTHON"),
}
)
case_size = sys.getsizeof(processed[-1]["input_output"])
assert (
case_size < 500 * 1024
), f"'input_output' exceeds 500KB ({case_size} bytes). Use remote testcase instead."
return processed
@ -73,7 +81,7 @@ def process_math_data(file_path: str) -> List[Dict]:
"task": "math",
"query_id": str(item["query_id"]),
"prompt": item["prompt"],
"solutions": item["solutions"],
"solutions": item.get("solutions", []),
}
)

View File

@ -3,15 +3,20 @@ import logging
import os
import random
import time
import traceback
from enum import Enum
from statistics import median
from typing import Any, Dict
import aiohttp
from functioncall.base import logging
try:
from realhf.base import constants, logging
except Exception:
import logging
logger = logging.getLogger("Functioncall")
constants = None
logger = logging.getLogger("function call")
FUNCTIONCALL_SERVICE_DOMAIN = os.getenv(
"FUNCTIONCALL_SERVICE_DOMAIN",
@ -19,32 +24,75 @@ FUNCTIONCALL_SERVICE_DOMAIN = os.getenv(
)
def check_payload(payload):
if not payload:
return False, {
"uid": payload.get("uid", ""),
"success": False,
"results": [
{
"success": False,
"reason": "Empty payload",
"errorType": "UnknownError",
}
],
}
if not payload.get("code"):
return False, {
"uid": payload.get("uid", ""),
"success": False,
"results": [
{"success": False, "reason": "Empty code", "errorType": "UnknownError"}
],
}
return True, {}
class Language(Enum):
PYTHON = 0
JAVA = 1
CPP = 2
C = 3
MATH = 4
SQL = 5
GO = 6
NODEJS = 7
CSHARP = 8
TYPESCRIPT = 9
JAVASCRIPT = 10
def __str__(self):
return f"{self.name.lower()}"
def calculate_percentile(elapsed_times, percentile):
sorted_times = sorted(elapsed_times)
index = int(len(sorted_times) * (percentile / 100))
return sorted_times[min(index, len(sorted_times) - 1)]
def has_system_error(response_json):
for result in response_json.get("results", []):
if result.get("errorType", "") == "SystemError":
return True, result
return False, None
async def async_invoke_function(
session: aiohttp.ClientSession,
function_name: str,
url: str,
timeout: aiohttp.ClientTimeout,
payload: Dict[str, Any] = None,
max_retries: int = 100,
initial_retry_interval: float = 0.5,
max_retry_interval: float = 10.0,
):
if payload is None:
payload = {}
url = f"{FUNCTIONCALL_SERVICE_DOMAIN}/hapis/faas.hcs.io/v1/functions/{function_name}/invoke"
params = {"invocationType": "RequestResponse"}
retries = 0
while retries < max_retries:
try:
async with session.post(
url,
params=params,
json=payload,
timeout=timeout,
) as response:
@ -55,94 +103,132 @@ async def async_invoke_function(
)
try:
result = await response.json()
return result, response.headers
response_json = await response.json()
exist, err_info = has_system_error(response_json)
if exist:
raise Exception(
f'SystemError detected, uid: {response_json.get("uid")}, err: {err_info}'
)
return response_json
except aiohttp.ContentTypeError as e:
raise Exception("Invalid JSON response") from e
except asyncio.TimeoutError as e:
logger.warning(
f"Request timeout after {timeout}s, URL: {url}, Headers: {session.headers}"
f'Request timeout after {timeout}s, uid: {payload.get("uid")}, URL: {url}'
)
break
return {
"uid": payload.get("uid", ""),
"success": False,
"results": [
{
"success": False,
"reason": "Function call timed out.",
"errorType": "UnknownError",
}
],
}
except Exception as e:
logger.error(
f"Async invocation failed on attempt {retries + 1}:{str(e)}, URL: {url}, Headers: {session.headers}"
f"Async invocation failed on attempt {retries + 1}:{str(e)}, uid: {payload.get('uid')}, URL: {url}"
)
retries += 1
if retries > max_retries:
return None, None
return {
"uid": payload.get("uid", ""),
"success": False,
"results": [
{
"success": False,
"reason": "Function call exceed max retries.",
"errorType": "UnknownError",
}
],
}
# Exponential backoff with random jitter
sleep_time = min(
initial_retry_interval * (2**retries) + random.uniform(0, 0.1),
initial_retry_interval * (2**retries) + random.uniform(0, 5),
max_retry_interval,
)
await asyncio.sleep(sleep_time)
async def batch_function_call_async(
payload_list, function_name, timeout, concurrency=1500
):
connector = aiohttp.TCPConnector(limit=0)
async def batch_function_call_async(payload_list, url, timeout, concurrency=1500):
connector = aiohttp.TCPConnector(
limit=concurrency,
ttl_dns_cache=300, # DNS cache
keepalive_timeout=80, # keepalive_timeout need to be smaller than the middle link idle-timeout
)
async with aiohttp.ClientSession(connector=connector) as session:
semaphore = asyncio.Semaphore(concurrency)
async def limited_task(payload):
if not payload:
return None
ok, err_rsp = check_payload(payload)
if not ok:
return err_rsp, 0
async with semaphore:
st = time.monotonic()
result = await async_invoke_function(
session, function_name, timeout, payload
)
result = await async_invoke_function(session, url, timeout, payload)
return result, time.monotonic() - st
tasks = [limited_task(payload) for payload in payload_list]
results = await asyncio.gather(*tasks, return_exceptions=True)
results = results if results else []
data_list = []
elapsed_times = []
max_elapsed = -1
max_elapsed_header = None
for (data, header), elapsed in results:
max_elapsed_uid = ""
for data, elapsed in results:
if elapsed > max_elapsed:
max_elapsed = elapsed
max_elapsed_header = header
max_elapsed_uid = data.get("uid")
data_list.append(data)
elapsed_times.append(elapsed)
# logger.debug(f"functioncall took {elapsed:.4f} seconds, header: {header}.)")
p50 = median(elapsed_times)
p90 = calculate_percentile(elapsed_times, 90)
p99 = calculate_percentile(elapsed_times, 99)
logger.info(
f"Longest functioncall {function_name} took {max_elapsed:.4f} seconds, header: {max_elapsed_header}, timeout: {timeout}, p50: {p50}, p90: {p90}, p99: {p99}"
f"Longest functioncall took {max_elapsed:.4f} seconds, timeout: {timeout}, uid: {max_elapsed_uid}, Active connections: {len(connector._conns)}, p50: {p50}, p90: {p90}, p99: {p99}"
)
return data_list
def get_function_name(runtime_type):
if runtime_type == "python_code":
return "realhf_code_verify"
if runtime_type == "python_live_code_bench":
return "python_live_code_bench"
elif runtime_type == "python_math":
return "python_math"
return "empty_code"
def get_runtime_name(runtime, language):
if runtime:
return runtime
else:
return str(language).lower() + "-default"
def batch_function_call(payload_list, runtime_type, timeout):
def caculate_concurrency():
# use 5000 cpu cores for one exp by default
concurrency_for_one_exp = 5000
try:
dp = constants.parallelism_group_size()
except Exception as e:
dp = 16
return concurrency_for_one_exp // dp
def batch_function_call(payload_list, task_type, timeout):
start_time = time.time()
function_name = get_function_name(runtime_type)
url = f"{FUNCTIONCALL_SERVICE_DOMAIN}/apis/functioncalls"
concurrency = caculate_concurrency()
logger.info(
f"Batch function call start, task type: {task_type}, request count: {len(payload_list)}, time: {time.ctime(start_time)} ms, concurrency: {concurrency}"
)
result = asyncio.run(
batch_function_call_async(payload_list, function_name, timeout)
batch_function_call_async(payload_list, url, timeout, concurrency=concurrency)
)
execution_time = time.time() - start_time
logger.info(
f"Batch function call done, runtime type: {runtime_type}, batch size: {len(payload_list)}, cost: {execution_time * 1000:.0f} ms"
f"Batch function call done, task type: {task_type}, batch size: {len(payload_list)}, cost: {execution_time * 1000:.0f} ms"
)
return result

View File

@ -1,30 +0,0 @@
import logging
from typing import Optional
import colorlog
LOG_FORMAT = "%(asctime)s.%(msecs)03d %(name)s %(levelname)s: %(message)s"
DATE_FORMAT = "%Y%m%d-%H:%M:%S"
LOGLEVEL = logging.DEBUG
formatter = colorlog.ColoredFormatter(
fmt="%(log_color)s" + LOG_FORMAT,
datefmt=DATE_FORMAT,
log_colors={
"DEBUG": "blue",
"INFO": "light_purple",
"WARNING": "yellow",
"ERROR": "red",
"CRITICAL": "bold_white,bg_red",
},
)
handler = logging.StreamHandler()
handler.setLevel(LOGLEVEL)
handler.setFormatter(formatter)
logging.basicConfig(level=LOGLEVEL, handlers=[handler])
def getLogger(name: Optional[str] = None):
return logging.getLogger(name)

View File

@ -0,0 +1,47 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import json
from datetime import datetime
try:
from realhf.base import constants, logging
logger = logging.getLogger("function call")
except Exception:
import logging
constants = None
logger = logging.getLogger("function call")
logger.setLevel(logging.DEBUG)
console = logging.StreamHandler()
console.setLevel(logging.DEBUG)
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
console.setFormatter(formatter)
logger.addHandler(console)
def construct_uid(query_id: str, start_idx: int, end_idx: int):
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
try:
trial_name = f"{constants.experiment_name()}-{constants.trial_name()}"
except Exception as e:
trial_name = "test"
uid = f"[{timestamp}-{trial_name}]-{query_id}-[{start_idx}-{end_idx}]"
return uid
def load_jsonl(file_path: str):
"""Load JSONL file with validation"""
try:
with open(file_path, "r", encoding="utf-8") as f:
return [json.loads(line) for line in f]
except FileNotFoundError:
print(f"ERROR: JSONL file not found: {file_path}")
raise
except json.JSONDecodeError as e:
print(f"ERROR: JSON parsing failed in {file_path}: {str(e)}")
raise

View File

@ -105,7 +105,7 @@ def run_test(sample, test=None, debug=False, timeout=6):
which_type = CODE_TYPE.call_based
if in_outs:
if in_outs.get("fn_name") is None:
if in_outs.get("fn_name", "") == "":
which_type = CODE_TYPE.standard_input # Standard input
method_name = None
else:
@ -116,8 +116,7 @@ def run_test(sample, test=None, debug=False, timeout=6):
print(f"loaded input_output = {datetime.now().time()}")
if test is None:
assert False, "should not happen: test code is none"
return in_outs, {"error": "no test code provided"}
return [False] * len(in_outs["inputs"]), {"error": "no test code provided"}
elif test is not None:
results = []
sol = "from string import *\nfrom re import *\nfrom datetime import *\nfrom collections import *\nfrom heapq import *\nfrom bisect import *\nfrom copy import *\nfrom math import *\nfrom random import *\nfrom statistics import *\nfrom itertools import *\nfrom functools import *\nfrom operator import *\nfrom io import *\nfrom sys import *\nfrom json import *\nfrom builtins import *\nfrom typing import *\nimport string\nimport re\nimport datetime\nimport collections\nimport heapq\nimport bisect\nimport copy\nimport math\nimport random\nimport statistics\nimport itertools\nimport functools\nimport operator\nimport io\nimport sys\nimport json\nsys.setrecursionlimit(6*10**5)\n"
@ -760,7 +759,7 @@ def reliability_guard(maximum_memory_bytes=None):
subprocess.Popen = None # type: ignore
__builtins__["help"] = None
# __builtins__["help"] = None
import sys
@ -769,3 +768,20 @@ def reliability_guard(maximum_memory_bytes=None):
sys.modules["resource"] = None
sys.modules["psutil"] = None
sys.modules["tkinter"] = None
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--tmp_id", type=str, required=True)
args = parser.parse_args()
all_input_data = []
with open(f"/tmp/{args.tmp_id}-input.json", "r") as temp_file:
input_data = json.load(temp_file)
result, info = run_test(**input_data)
saved_result = {"result": result, "info": info}
with open(f"/tmp/{args.tmp_id}-output.json", "w", encoding="utf-8") as temp_file:
json.dump(saved_result, temp_file)

View File

@ -1,16 +1,22 @@
import concurrent.futures
import json
import multiprocessing
import os
import signal
import subprocess
import sys
import time
import traceback
import uuid
from collections import defaultdict
from io import StringIO
from typing import Dict, List
from functioncall.base import logging
from functioncall.code.function.testing_util import run_test
from functioncall.base.utils import load_jsonl, logger
from realhf.base import logging
logger = logging.getLogger("Functioncall")
SINGLE_CASE_EXEC_TIMEOUT = 6
logger = logging.getLogger("function call")
def capture_stdout(code):
@ -27,121 +33,116 @@ def capture_stdout(code):
return fake_stdout.getvalue()
def _temp_run(problem, generation, debug, result):
def call_verify(problem, generation, debug, timeout=SINGLE_CASE_EXEC_TIMEOUT):
tmp_id = str(uuid.uuid4())
input_data = {
"sample": problem,
"test": generation,
"debug": debug,
"timeout": timeout,
}
with open(f"/tmp/{tmp_id}-input.json", "w") as temp_file:
json.dump(input_data, temp_file)
start_time = time.time()
venv_python = "python3"
pro = subprocess.Popen(
" ".join(
[
venv_python,
"functioncall/code/function/testing_util.py",
"--tmp_id",
tmp_id,
]
),
shell=True,
preexec_fn=os.setsid,
stdout=subprocess.DEVNULL,
)
try:
if debug:
logger.debug(f"Running test for problem: {problem}")
r = run_test(sample=problem, test=generation, debug=debug)
result.append(r)
if debug:
logger.debug(f"Test completed with result: {result}")
pro.wait(600)
except Exception as e:
pass
try:
os.killpg(os.getpgid(pro.pid), signal.SIGTERM)
except ProcessLookupError:
pass
result = {"result": [False], "info": {}}
try:
with open(f"/tmp/{tmp_id}-output.json", "r") as f:
result = json.load(f)
except FileNotFoundError as e:
logger.warning(
f"Error in _temp_run: {e}\n"
f"traceback: {''.join(traceback.format_exception(*sys.exc_info()))}\n"
f"problem:{problem}"
f"{problem['query_id']}: Failed to parse generated answers. FileNotFoundError. Set 0 reward."
)
except Exception as e:
logger.warning(
f"{problem['query_id']}: Failed to parse generated answers. {e}. Set 0 reward."
)
finally:
if os.path.exists(f"/tmp/{tmp_id}-input.json"):
os.remove(f"/tmp/{tmp_id}-input.json")
if os.path.exists(f"/tmp/{tmp_id}-output.json"):
os.remove(f"/tmp/{tmp_id}-output.json")
execution_time = time.time() - start_time
logger.info(
f'[_temp_run] query_id: {problem["problem_id"]}, start_time: {str(start_time)}, Time elapsed: {execution_time * 1000:.0f} ms'
f'[call_verify] query_id: {problem["query_id"]}, start_time: {str(start_time)}, Time elapsed: {execution_time * 1000:.0f} ms'
)
def check_correctness(problem, generation, timeout, debug=False):
"""Check correctness of code generation with a global timeout.
The global timeout is to catch some extreme/rare cases not handled by the timeouts
inside `run_test`"""
if debug:
# FIXME: error variable "problem" is not defined
result = capture_stdout(
"from functioncall.code.function.testing_util import run_test\n"
+ "run_test(sample=problem, test=generation, debug=debug)"
)
return result[0], result[1]
start_time = time.time()
manager = multiprocessing.Manager()
result = manager.list()
p = multiprocessing.Process(
target=_temp_run, args=(problem, generation, debug, result)
)
p.start()
p.join(timeout=timeout + 1)
if p.is_alive():
if debug:
logger.debug(f"Process is still alive. Killing the process.")
p.kill()
if not result:
# Remark: ideally we would consider that all tests failed but we can't access number of tests here easily
# so we use 21=the average number of tests for a smaple in the test split instead
avg_number_tests = 21
result = [[-1 for _ in range(avg_number_tests)], {}]
if debug:
logger.debug(f"Global timeout occurred, returning default result.")
if debug:
logger.debug(f"Final result: {result}")
execution_time = time.time() - start_time
logger.info(
f'[check_correctness] query_id: {problem["problem_id"]}, start_time: {str(start_time)}, Time elapsed: {execution_time * 1000:.0f} ms'
)
return result[0]
return result["result"], result["info"]
def code_verify(id2info, generateds, query_ids, debug=False):
assert len(generateds) == len(query_ids)
problems = [id2info[qid] for qid in query_ids]
result = []
final_results = []
infer_args = []
for query_id, generated, problem in zip(query_ids, generateds, problems):
logger.debug(f"run_batch_code, query_id: {query_id}")
try:
curr_res, metadata = check_correctness(
problem=problem, generation=generated, timeout=6000, debug=debug
)
infer_args.append((problem, generated, debug, SINGLE_CASE_EXEC_TIMEOUT))
if any(x != True for x in curr_res):
logger.debug(f"id:{query_id}, Results were not all True: {metadata}")
result.append(0)
else:
# print(f"id:{problem["problem_id"]}, result : {curr_res}")
result.append(1)
run_results = []
num_process = max(1, os.cpu_count() // 8)
with concurrent.futures.ProcessPoolExecutor(num_process) as executor:
run_results = executor.map(call_verify, *zip(*infer_args))
except Exception as e:
exc_info = sys.exc_info()
logger.error(
f"test framework exception = {repr(e)}{e}\n{traceback.format_exception(*exc_info)}"
)
result.append(0)
for run_result in run_results:
curr_res, metadata = run_result
if any(x != True for x in curr_res):
final_results.append(0)
else:
final_results.append(1)
return result
return final_results
if __name__ == "__main__":
path = "/storage/openpsi/data/code/apps/test.jsonl"
data = []
with open(path, "r") as f:
code_data = [json.loads(l) for l in f.readlines()]
data_list = load_jsonl("functioncall/test/test_success_dataset.jsonl")
id2info = defaultdict(dict)
for item in data_list:
id2info[item["query_id"]] = item
id2info = {}
solutions = []
query_ids = []
for i in range(10):
problem = code_data[i]
problem["problem_id"] = problem["id"]
id2info[problem["problem_id"]] = problem
solutions.append(json.loads(problem["solutions"])[0])
query_ids.append(problem["id"])
def create_test_params(count=10):
query_ids = []
generateds = []
cnt = 0
result = code_verify(
id2info,
solutions,
query_ids,
debug=False,
)
for d in data_list:
if cnt >= count:
break
if not d["solutions"] or d["query_id"] not in id2info:
continue
query_ids.append(d["query_id"])
generateds.extend(d["solutions"])
cnt += 1
return generateds, query_ids
generateds, query_ids = create_test_params(100)
scale = 1
print(f"generateds:, query_ids:{query_ids}")
result = code_verify(id2info, generateds * scale, query_ids * scale)
print(result)

View File

@ -2,55 +2,115 @@ import json
import os
import random
from collections import defaultdict
from datetime import datetime
from functioncall.base import logging
from functioncall.base.call import batch_function_call
from functioncall.base.call import Language, batch_function_call, get_runtime_name
from functioncall.base.utils import construct_uid, load_jsonl, logger
logger = logging.getLogger("Functioncall")
SINGLE_CASE_EXEC_TIMEOUT = 6
TEST_CASE_BATCH_SIZE = 1
FUNCTIONCALL_TIMEOUT = 1000
def round_up_memory(memory):
if memory <= 0:
return 0
rounded = ((memory + 255) // 256) * 256
return 0 if rounded > 1024 else rounded
def construct_testcases(
inputs: list, outputs: list, index: tuple, remote: bool = False, is_ut: bool = False
) -> dict:
result = []
if is_ut:
return result
for i in range(*index):
input_, output_ = inputs[i].strip(), outputs[i].strip()
if not remote:
result.append({"input": input_, "expectedOutput": output_})
continue
oss_basepath = "http://antsys-hcsfaas-images-dev.cn-heyuan-alipay-office.oss-alipay.aliyuncs.com/"
input_url = (
input_ if input_.startswith("http") else os.path.join(oss_basepath, input_)
)
output_url = (
output_
if output_.startswith("http")
else os.path.join(oss_basepath, output_)
)
result.append({"input": input_url, "expectedOutput": output_url})
return result
def load_problems_with_testcase_batch(
id2info, query_ids, debug=False, test_case_batch_size=None
id2info, query_ids, generateds, timeout_for_testcase, test_case_batch_size
):
problem_map = defaultdict(list)
problem_list = []
for idx, query_id in enumerate(query_ids):
problem = id2info[query_id]
# parse one problem
language = problem.get("language", "PYTHON").upper()
timeout = min(
100, max(0.1, float(problem.get("timeout", timeout_for_testcase)) * 1.5)
) # [0.1, 100] s
memory = round_up_memory(problem.get("memory", 0))
input_output = json.loads(problem["input_output"])
fn_name = input_output.get("fn_name", "")
remote = input_output.get("remote", False)
inputs = input_output.get("inputs", [])
outputs = input_output.get("outputs", [])
assert len(inputs) == len(
outputs
), f"Inputs({len(inputs)}) and outputs({len(outputs)}) mismatch for {query_id}"
other_io_fields = {
k: v for k, v in input_output.items() if k not in ["inputs", "outputs"]
}
# create batches for testcases
if not test_case_batch_size or test_case_batch_size <= 0:
test_case_batch_size = len(inputs)
assert (
language in Language.__members__
), f"{language} is not a valid Language name"
for batch_idx in range(0, len(inputs), test_case_batch_size):
batch_io = {
**other_io_fields,
"inputs": inputs[batch_idx : batch_idx + test_case_batch_size],
"outputs": outputs[batch_idx : batch_idx + test_case_batch_size],
}
is_ut = len(inputs) == 0
# isFastFail means the function call returns immediately as soon as any testcase fails.
isFastFail = True
# create batches for testcases
case_size = 1 if is_ut else len(inputs)
test_case_batch_size = min(max(1, test_case_batch_size), case_size)
for batch_idx in range(0, case_size, test_case_batch_size):
end_idx = min(case_size, batch_idx + test_case_batch_size)
testcases = construct_testcases(
inputs, outputs, (batch_idx, end_idx), remote, is_ut
)
sub_problem = {
"problem_id": query_id,
"input_output": json.dumps(batch_io),
"batche_index": batch_idx,
"uid": construct_uid(query_id, batch_idx, end_idx),
"language": language,
"runtime": get_runtime_name("", language),
"code": generateds[idx],
"entryFunction": fn_name,
"isFastFail": isFastFail,
"isRemote": remote,
"testcases": testcases,
"timeout": timeout,
"memory": memory,
"query_index": idx,
}
if debug:
sub_problem["solutions"] = problem.get("solutions", [])
problem_map[query_id].append(sub_problem)
problem_list.append(sub_problem)
return problem_map
return problem_list
def code_verify(
id2info, generateds, query_ids, debug=False, timeout=1000, timeout_for_testcase=6
id2info,
generateds,
query_ids,
timeout=FUNCTIONCALL_TIMEOUT,
timeout_for_testcase=SINGLE_CASE_EXEC_TIMEOUT,
test_case_batch_size=TEST_CASE_BATCH_SIZE,
):
assert len(generateds) == len(query_ids), (
len(generateds),
@ -58,71 +118,64 @@ def code_verify(
)
payload_list = []
global_problems = load_problems_with_testcase_batch(
payload_list = load_problems_with_testcase_batch(
id2info,
query_ids,
debug=True,
test_case_batch_size=20,
generateds,
timeout_for_testcase,
test_case_batch_size,
)
for idx, query_id in enumerate(query_ids):
problems = global_problems[query_id]
for problem in problems:
payload_list.append(
{
"problem": problem,
"code": generateds[idx],
"debug": debug,
"timeout": timeout_for_testcase,
"query_index": idx,
}
)
logger.debug(
f"code_verify, payload_list size: {len(payload_list)}, query size: {len(query_ids)}, query_id_0: {query_ids[0]}"
logger.info(
f"code_verify start, request count: {len(payload_list)}, query size: {len(query_ids)}, query_id_0: {query_ids[0]}"
)
rsp_list = batch_function_call(payload_list, "python_code", timeout=timeout)
rsp_list = batch_function_call(payload_list, "code", timeout)
results = [1] * len(query_ids)
results = [1] * len(query_ids) if len(rsp_list) else [0] * len(query_ids)
for idx, rsp in enumerate(rsp_list):
query_index = payload_list[idx]["query_index"]
query_id = query_ids[query_index]
value = 0
if rsp and "result" in rsp and not any(x != True for x in rsp["result"]):
if rsp and rsp.get("success", False):
value = 1
else:
logger.debug(
f"Functioncall code verify not passed, query index: {query_index}, query id: {query_id}, results: {rsp}"
f'Functioncall code verify not passed, uid: {rsp.get("uid")}, query id: {query_id}, results: {rsp}'
)
results[query_index] = results[query_index] and value
logger.info(
f"code_verify finished, request count: {len(payload_list)}, query count: {len(query_ids)}, result count: {len(results)}"
)
return results
if __name__ == "__main__":
path = "/storage/openpsi/data/code/apps/codeparrot-apps-test.jsonl"
data = []
with open(path, "r") as f:
code_data = [json.loads(l) for l in f.readlines()]
id2info = {}
data_list = load_jsonl("functioncall/test/test_success_dataset.jsonl")
id2info = defaultdict(dict)
for item in data_list:
id2info[item["query_id"]] = item
def create_test_params(count=10):
global id2info
query_ids = []
generateds = []
cnt = 0
while cnt < count:
d = random.choice(code_data)
if not d["solutions"]:
for d in data_list:
if cnt >= count:
break
if d["query_id"] not in id2info:
continue
id2info[d["id"]] = d
query_ids.append(d["id"])
generateds.append(d["solutions"][0])
query_ids.append(d["query_id"])
generateds.extend(d["solutions"])
cnt += 1
return generateds, query_ids
generateds, query_ids = create_test_params(100)
result = code_verify(id2info, generateds, query_ids, True)
scale = 1
print(f"generateds:, query_ids:{query_ids}")
result = code_verify(id2info, generateds * scale, query_ids * scale)
print(result)

View File

@ -1,12 +1,12 @@
import json
import os
import time
from collections import defaultdict
from datetime import datetime
from typing import List
from functioncall.base import logging
from functioncall.base.call import batch_function_call
logger = logging.getLogger("Functioncall")
from functioncall.base.call import Language, batch_function_call, get_runtime_name
from functioncall.base.utils import construct_uid, logger
def math_verify(
@ -32,41 +32,54 @@ def math_verify(
start_time = time.time()
batch_args_list = []
for i in range(0, len(parameters), batch_size):
answers, solutions, indices = zip(*parameters[i : i + batch_size])
end_idx = min(i + batch_size, len(parameters))
answers, solutions, indices = zip(*parameters[i:end_idx])
batch_args = {
"answers": list(answers),
"solutions": list(solutions),
"query_ids": [query_ids[i] for i in indices],
}
batch_args_list.append(batch_args)
sub_problem = {
"uid": construct_uid("math", i, end_idx),
"language": str(Language.MATH).upper(),
"runtime": get_runtime_name(None, str(Language.MATH)),
"code": 'print("hello math!")',
"testcases": [{}] * (end_idx - i), # required filed
"timeout": 5,
"isFastFail": True,
"extraInfo": batch_args,
}
results_batch = batch_function_call(batch_args_list, "python_math", timeout)
batch_args_list.append(sub_problem)
results_batch = batch_function_call(batch_args_list, "math", timeout)
labels = [0] * len(query_ids)
# Map results back to original indices
index = 0
for batch_idx, results in enumerate(results_batch):
if not isinstance(results, list) or len(results) == 0:
index += len(batch_args_list[batch_idx]["answers"])
# check result format
if not (
isinstance(results, dict)
and "results" in results
and isinstance(results["results"], list)
and results["results"]
and all(isinstance(item, dict) for item in results["results"])
):
index += len(batch_args_list[batch_idx]["extraInfo"]["query_ids"])
logger.warning(
f"Invalid functioncall math results: {results}, batch index:{batch_idx}, query index: {query_indices[index]}, params: {batch_args_list[batch_idx]['answers']}."
f"Invalid functioncall math results: {results}, batch index:{batch_idx}, query index: {query_indices[index]}, params: {batch_args_list[batch_idx]}."
)
continue
for result in results:
for result in results["results"]:
query_index = query_indices[index]
if (
isinstance(result, list)
and len(result) > 0
and (isinstance(result[0], int) and result[0] in [0, 1])
):
labels[query_index] = result[0] or labels[query_index]
else:
logger.warning(
f"Invalid functioncall math result: {result}, index:{index}, qeury_id: {query_ids[query_index]}."
)
# set label as 1 if any of the solutions matches the answer
labels[query_index] = (
int(result.get("success", False)) or labels[query_index]
)
index += 1
logger.info(
@ -77,7 +90,7 @@ def math_verify(
if __name__ == "__main__":
sample = {
"answers": ["-\\frac{2}{3}"],
"answers": ["\\boxed{-\\frac{2}{3}}"],
"solutions": [
"1. **Apply the operation $\\otimes$ to the innermost parentheses first:**\n \\[\n (1 \\otimes 2) \\otimes 3 = \\left(\\frac{1^2}{2}\\right) \\otimes 3 = \\frac{1}{2} \\otimes 3\n \\]\n \\[\n 1 \\otimes (2 \\otimes 3) = 1 \\otimes \\left(\\frac{2^2}{3}\\right) = 1 \\otimes \\frac{4}{3}\n \\]\n\n2. **Calculate each part using the definition of $\\otimes$:**\n \\[\n \\frac{1}{2} \\otimes 3 = \\frac{\\left(\\frac{1}{2}\\right)^2}{3} = \\frac{\\frac{1}{4}}{3} = \\frac{1}{12}\n \\]\n \\[\n 1 \\otimes \\frac{4}{3} = \\frac{1^2}{\\frac{4}{3}} = \\frac{1}{\\frac{4}{3}} = \\frac{3}{4}\n \\]\n\n3. **Subtract the two results:**\n \\[\n \\left(\\frac{1}{12}\\right) - \\left(\\frac{3}{4}\\right) = \\frac{1}{12} - \\frac{9}{12} = -\\frac{8}{12} = -\\frac{2}{3}\n \\]\n\n4. **Conclude with the final answer:**\n \\[\n \\boxed{A}\n \\]",
"\\boxed{-\\frac{2}{3}}",
@ -85,10 +98,11 @@ if __name__ == "__main__":
}
id2info = {"fe11b471-1aa9-4867-958f-a0a811c85f92": sample}
scale = 50
start_time = time.time()
result = math_verify(
id2info,
sample["answers"] * 100,
["fe11b471-1aa9-4867-958f-a0a811c85f92" for _ in range(100)],
sample["answers"] * scale,
["fe11b471-1aa9-4867-958f-a0a811c85f92"] * scale,
)
print(result)

View File

@ -0,0 +1,298 @@
import copy
import json
import logging
import math
import os
import pickle
import sys
import time
from collections import defaultdict
from datetime import datetime
from multiprocessing import Manager, Pool, cpu_count, shared_memory
from typing import Any, Dict, List
import numpy as np
from functioncall.code.verify import code_verify
logger = logging.getLogger("function call")
def parallel_code_verify(
id2info: Dict[str, Any],
generateds: List[str],
query_ids: List[str],
test_case_batch_size: int,
num_processes: int = min(cpu_count(), 128),
) -> List[Any]:
shm = None
pool = None
try:
# set id2info in shared memory
serialized_dict = pickle.dumps(id2info)
buffer = np.frombuffer(serialized_dict, dtype=np.uint8)
shm = shared_memory.SharedMemory(create=True, size=buffer.nbytes)
buffer_shared = np.ndarray(buffer.shape, dtype=buffer.dtype, buffer=shm.buf)
buffer_shared[:] = buffer[:]
shared_dict = (shm.name, buffer.shape, buffer.dtype)
chunk_size = math.ceil(len(generateds) / num_processes)
chunks = [
(
i,
shared_dict,
generateds[i : i + chunk_size],
query_ids[i : i + chunk_size],
test_case_batch_size,
)
for i in range(0, len(generateds), chunk_size)
]
print(
f"parallel_code_verify start generateds_size: {len(generateds)}, query_ids_size:{len(query_ids)}, {num_processes} processes"
f"using "
)
pool = Pool(processes=num_processes)
start_time = time.time()
chunk_results = pool.starmap(process_ordered_chunk, chunks)
flat_results = [item for chunk in chunk_results for item in chunk]
duration = time.time() - start_time
print(
f"Processed {len(generateds)} items in {duration:.2f} seconds "
f"using {num_processes} processes"
)
return flat_results
except KeyboardInterrupt:
print("\nReceived Ctrl+C, terminating processes...")
if pool is not None:
pool.terminate()
pool.join()
return []
except Exception as e:
print(f"Error occurred: {str(e)}")
return []
finally:
if shm is not None:
shm.close()
shm.unlink()
if pool is not None:
pool.close()
def process_ordered_chunk(
index,
shared_dict,
generateds,
query_ids,
test_case_batch_size,
) -> List[tuple[int, Any]]:
start = time.monotonic()
logger.info(
f"Process start at {start}s, chunk_index: {index}, chunk_size: {len(generateds)}, query_size: {len(query_ids)}"
)
try:
shm_name, shape, dtype = shared_dict
existing_shm = shared_memory.SharedMemory(name=shm_name)
buffer = np.ndarray(shape, dtype=dtype, buffer=existing_shm.buf)
id2info = pickle.loads(buffer.tobytes())
results = code_verify(
id2info, generateds, query_ids, test_case_batch_size=test_case_batch_size
)
if len(results) != len(generateds):
raise ValueError(
f"Result length mismatch: expected {len(generateds)}, got {len(results)}"
)
logger.info(f"Process {index} completed in {time.monotonic()-start:.2f}s")
return results
except pickle.UnpicklingError as e:
logger.error(f"Failed to deserialize shared memory: {e}")
return [str(e)] * len(query_ids)
except Exception as e:
logger.error(
f"Process {index} failed in {time.monotonic() - start:.2f}s, err: {str(e)}"
)
return [str(e)] * len(query_ids)
finally:
if "existing_shm" in locals():
existing_shm.close()
def load_jsonl(file_path: str):
"""Load JSONL file with validation"""
try:
with open(file_path, "r", encoding="utf-8") as f:
return [json.loads(line) for line in f]
except FileNotFoundError:
print(f"ERROR: JSONL file not found: {file_path}")
raise
except json.JSONDecodeError as e:
print(f"ERROR: JSON parsing failed in {file_path}: {str(e)}")
raise
def save_jsonl(samples, save_path):
with open(save_path, "w", encoding="utf-8") as f:
for sample in samples:
f.write(json.dumps(sample, ensure_ascii=False) + "\n")
print("Saved to", save_path)
def load_jsonl_stream(file_path):
with open(file_path, "r", encoding="utf-8") as f:
for line in f:
yield json.loads(line)
def lcb_dataset_eval():
data4 = load_jsonl(
"/storage/openpsi/data/code/live_code_bench/live_code_bench_v4_v5-r1-distilled-prompt-fnname.jsonl"
)
id2info = defaultdict(dict)
for item in data4:
query_id = str(item["query_id"])
id2info[query_id] = item
def create_test_params(count=-1):
query_ids = []
generateds = []
cnt = 0
file_path = "/storage/openpsi/users/meijun.mei/datasets/Scenario.codegeneration_10_0.2_eval_all.json"
raw_data = []
with open(file_path, "r", encoding="utf-8") as f:
raw_data = [line for line in json.load(f)]
for d in raw_data:
if count > 0 and cnt >= count:
break
if not d["code_list"] or d["question_id"] not in id2info:
continue
generateds.extend(d["code_list"])
query_ids.extend([d["question_id"]] * len(d["code_list"]))
cnt += len(d["code_list"])
return generateds, query_ids
generateds, query_ids = create_test_params()
start_time = time.time()
scale = 2
result = parallel_code_verify(
id2info, generateds * scale, query_ids * scale, num_processes=16
)
# vals, metas =
print(f"Total results: {result}")
logger.info(
f"Process results: {result}, size: {len(generateds)}, in {time.time()-start_time:.2f}s"
)
def build_sol_id(query_id, solution_index):
return query_id + f"[solution{solution_index}]"
def parse_sol_id(sol_id):
query_id, solution_part = sol_id.split("[", 1)
solution_content = solution_part.rstrip("]")
return query_id, solution_content
def statics_result(result, query_ids):
result_statistics = defaultdict(
lambda: {"query_id": "", "pass": True, "solutions": []}
)
for i, query_id in enumerate(query_ids):
org_id, sol_idx = parse_sol_id(query_id)
result_statistics[org_id]["query_id"] = org_id
result_statistics[org_id]["solutions"].append({sol_idx: bool(result[i])})
result_statistics[org_id]["solutions"] = sorted(
result_statistics[org_id]["solutions"], key=lambda x: list(x.keys())[0]
)
if not result[i]:
result_statistics[org_id]["pass"] = False
return list(result_statistics.values())
def standard_dataset_eval(
dataset_path, code_count=0, test_case_batch_size=20, dry_run=False
):
id2info = defaultdict(dict)
generateds, query_ids = [], []
cnt = 0
testcase_in_dataset = 0
testcases_in_runtime = 0
request_size = 0
for item in load_jsonl_stream(dataset_path):
if code_count and cnt >= code_count:
break
if not item["solutions"]:
continue
generateds.extend(item["solutions"])
# set unique query_id for each solution code
for i in range(len(item["solutions"])):
query_id = build_sol_id(item["query_id"], i)
query_ids.append(query_id)
id2info[query_id] = copy.copy(item)
id2info[query_id]["query_id"] = query_id
# metrics
case_size = sys.getsizeof(item["input_output"])
assert (
case_size < 500 * 1024
), f"'input_output' exceeds 500KB ({case_size} bytes). Use remote testcase instead."
cnt += len(item["solutions"])
case_size = len(json.loads(item["input_output"]).get("inputs", []))
testcase_in_dataset += case_size
testcases_in_runtime += case_size * len(item["solutions"])
request_size += math.ceil(case_size / test_case_batch_size) * len(
item["solutions"]
)
start_time = time.time()
logger.info(
f"Start process, code size: {len(generateds)}, request size: {request_size}, testcase_in_dataset: {testcase_in_dataset}, testcases_in_runtime: {testcases_in_runtime}"
)
if dry_run:
return
result = parallel_code_verify(
id2info, generateds, query_ids, test_case_batch_size, num_processes=16
)
# passed solutions
solution_pass_rate = result.count(1) / len(result)
logger.info(
f"Process results: {result}, code size: {len(generateds)},request size: {request_size}, testcase_in_dataset: {testcase_in_dataset}, testcases_in_runtime: {testcases_in_runtime}, solution_pass_rate:{solution_pass_rate} in {time.time()-start_time:.2f}s"
)
result_statistics = statics_result(result, query_ids)
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
save_jsonl(
result_statistics,
os.path.basename(dataset_path) + f"{timestamp}.stat",
)
if __name__ == "__main__":
# lcb_dataset_eval()
standard_dataset_eval(
"/storage/openpsi/users/meijun.mei/datasets/loj_0410_format2.jsonl",
code_count=0,
dry_run=False,
)
# standard_dataset_eval(
# "/storage/openpsi/data/code/live_code_bench_for_test/live_code_bench_v4_v5-for-test-remote.jsonl",
# code_count=0,
# dry_run=False,
# )

View File

@ -0,0 +1 @@
{"task": "code", "query_id": "0001", "prompt": "", "solutions": [""], "input_output": "{\"inputs\": [\"2\", \"2\", \"2\"], \"outputs\": [\"4\\n\", \"4\\n\", \"4\\n\"], \"fn_name\": \"\", \"remote\": false}", "language": "PYTHON"}

View File

@ -0,0 +1,6 @@
{"task": "code", "query_id": "0001", "prompt": "", "solutions": ["from typing import *\n\nclass Solution:\n def solveNQueens(self, n: int) -> List[List[str]]:\n def generateBoard():\n board = list()\n for i in range(n):\n row[queens[i]] = \"Q\"\n board.append(\"\".join(row))\n row[queens[i]] = \".\"\n return board\n\n def solve(row: int, columns: int, diagonals1: int, diagonals2: int):\n if row == n:\n board = generateBoard()\n solutions.append(board)\n else:\n availablePositions = ((1 << n) - 1) & (~(columns | diagonals1 | diagonals2))\n while availablePositions:\n position = availablePositions & (-availablePositions)\n availablePositions = availablePositions & (availablePositions - 1)\n column = bin(position - 1).count(\"1\")\n queens[row] = column\n solve(row + 1, columns | position, (diagonals1 | position) << 1, (diagonals2 | position) >> 1)\n\n solutions = list()\n queens = [-1] * n\n row = [\".\"] * n\n solve(0, 0, 0, 0)\n return solutions\n# Test case 1: Smallest case, n = 1\n# There is only one queen, so the only solution is a board with a single 'Q'.\nsolution = Solution()\nassert solution.solveNQueens(1) == [['Q']]\n"], "input_output": "{\"inputs\": [], \"outputs\": [], \"fn_name\": \"\", \"remote\": false}", "language": "PYTHON"}
{"task": "code", "query_id": "0002", "prompt": "", "solutions": ["package main\n\nimport (\n\t\"bufio\"\n\t\"fmt\"\n\t\"os\"\n\t\"strconv\"\n\t\"strings\"\n)\n\nfunc main() {\n\tnums, target := arrayToOneStyleInput()\n\tresult := twoSum(nums, target)\n\tfmt.Println(result)\n}\n\nfunc twoSum(nums []int, target int) []int {\n\tprevNums := map[int]int{}\n\tfor i, num := range nums {\n\t\ttargetNum := target - num\n\t\ttargetNumIndex, ok := prevNums[targetNum]\n\t\tif ok {\n\t\t\treturn []int{targetNumIndex, i}\n\t\t} else {\n\t\t\tprevNums[num] = i\n\t\t}\n\t}\n\treturn []int{}\n}\n\nfunc arrayToOneStyleInput() ([]int, int) {\n\t// 读取数组\n\tscanner := bufio.NewScanner(os.Stdin)\n\tscanner.Scan()\n\tarrStr := strings.Trim(scanner.Text(), \"[]\")\n\tarr := stringToIntSlice(strings.ReplaceAll(arrStr, \",\", \" \"))\n\n\t// 读取目标值\n\tscanner.Scan()\n\ttarget, _ := strconv.Atoi(scanner.Text())\n\n\treturn arr, target\n}\n\nfunc stringToIntSlice(s string) []int {\n\tparts := strings.Split(s, \" \")\n\tres := make([]int, len(parts))\n\tfor i, p := range parts {\n\t\tres[i], _ = strconv.Atoi(p)\n\t}\n\treturn res\n}\n"], "input_output": "{\"inputs\": [\"https://artifacts.antgroup-inc.cn/artifact/repositories/artifacts-pre-test-common/runtime/golang/1.0.1/input.txt\"], \"outputs\": [\"https://artifacts.antgroup-inc.cn/artifact/repositories/artifacts-pre-test-common/runtime/golang/1.0.1/output.txt\"], \"fn_name\": \"\", \"remote\": true }", "language": "GO"}
{"task": "code", "query_id": "0003", "prompt": "", "solutions": ["public class TestMain {\n public static void main(String[] args) {\n assert \"test\".equals(\"test\");\n }\n}"], "input_output": "{\"inputs\": [], \"outputs\": [], \"fn_name\": \"\", \"remote\": false}", "language": "JAVA"}
{"task": "code", "query_id": "0004", "prompt": "", "solutions": ["#include <iostream>\n#include <string>\nint main() {\n std::string name = \"Alice\";\n std::cout << \"Hello, \" << name << \"! Welcome to the world of C++!\\n\";\n return 0;\n}\n"], "input_output": "{\"inputs\": [], \"outputs\": [], \"fn_name\": \"\", \"remote\": false}", "language": "CPP"}
{"task": "code", "query_id": "0005", "prompt": "", "solutions": ["import time\n\ndef square(num):\n return num ** 2\n\nresult=square(int(input()))\nprint(result)"], "input_output": "{\"inputs\": [\"2\", \"2\", \"2\"], \"outputs\": [\"4\\n\", \"4\\n\", \"4\\n\"], \"fn_name\": \"\", \"remote\": false}", "language": "PYTHON"}
{"task": "code", "query_id": "0006", "prompt": "", "solutions": ["#include <iostream>\n\nint square(int number) {\n return number * number;\n}\n\nint main() {\n int num;\n std::cin >> num;\n\n int result = square(num);\n std::cout << result << std::endl;\n\n return 0;\n}"], "input_output": "{\"inputs\": [\"http://antsys-hcsfaas-images-dev.cn-heyuan-alipay-office.oss-alipay.aliyuncs.com/functioncall/content_2.txt\"], \"outputs\": [\"http://antsys-hcsfaas-images-dev.cn-heyuan-alipay-office.oss-alipay.aliyuncs.com/functioncall/content_4.txt\"], \"fn_name\": \"\", \"remote\": true }", "language": "CPP"}

View File

@ -20,5 +20,4 @@ from .api.core.model_api import (
PipelinableEngine,
ReaLModelConfig,
)
__version__ = "0.3.0"
from .version import __version__

View File

@ -1,3 +1,4 @@
import os
from dataclasses import asdict, dataclass, field, fields, is_dataclass
from typing import Dict, List, Optional, Tuple, Type, Union
@ -304,12 +305,90 @@ class SGLangConfig:
# but we disable it to avoid precision issues
chunked_prefill_size: Optional[int] = -1
max_prefill_tokens: int = 32768
max_prefill_tokens: int = 16384
schedule_policy: str = "lpm"
schedule_conservativeness: float = 1.0
cpu_offload_gb: int = 0
hybrid_train: bool = False
# logging
log_level: str = "info"
log_level_http: Optional[str] = "warning"
log_requests: bool = False
log_requests_level: int = 0
show_time_cost: bool = False
enable_metrics: bool = True # Exports Prometheus-like metrics
decode_log_interval: int = 1000 # How often (in tokens) to log decode progress.
# Use staticmethod to make OmegaConf happy.
@staticmethod
def build_cmd(
sglang_config: "SGLangConfig",
model_path,
tp_size,
server_index,
base_gpu_id,
):
from realhf.base import constants, network, pkg_version, seeding
from realhf.experiments.common.utils import asdict as conf_as_dict
args: Dict = conf_as_dict(sglang_config)
args.pop("hybrid_train")
args["random_seed"] = seeding.get_seed()
host_ip = network.gethostip()
host = "localhost" if not sglang_config.enable_metrics else host_ip
args = dict(
host=host,
model_path=model_path,
# Model and tokenizer
tokenizer_path=model_path,
tokenizer_mode="auto",
load_format="auto",
trust_remote_code=True,
kv_cache_dtype="auto",
device="cuda",
served_model_name=f"{constants.experiment_name()}/{constants.trial_name()}/{model_path}",
is_embedding=False,
skip_tokenizer_init=True,
# Other runtime options
tp_size=tp_size,
# Because we have set CUDA_VISIBLE_DEVICES to a single GPU in each process
base_gpu_id=base_gpu_id,
file_storage_path=os.path.join(
constants.SGLANG_CACHE_PATH,
f"sglang_storage{server_index}",
),
# Data parallelism
dp_size=1, # TODO: check whether we require SGLang dp
load_balance_method="round_robin",
# Expert parallelism
ep_size=1, # TODO: check
nnodes=1,
node_rank=0,
**args,
)
if pkg_version.is_version_less("sglang", "0.4.4"):
args.pop("log_requests_level")
if pkg_version.is_version_less("sglang", "0.4.3"):
args.pop("enable_nccl_nvls")
args.pop("triton_attention_num_kv_splits")
args.pop("cuda_graph_bs")
args.pop("enable_memory_saver")
args.pop("allow_auto_truncate")
args.pop("file_storage_path")
flags = []
for k, v in args.items():
if v is None or v is False or v == "":
continue
if v is True:
flags.append(f"--{k.replace('_','-')} ")
continue
flags.append(f"--{k.replace('_','-')} {v}")
flags = " ".join(flags)
return f"python3 -m sglang.launch_server {flags}"
@dataclass
class DistributedDataParallelConfig:
@ -317,7 +396,7 @@ class DistributedDataParallelConfig:
Refer to Megatron-LM documentation for details.
"""
grad_reduce_in_fp32: bool = False
grad_reduce_in_fp32: bool = True
overlap_grad_reduce: bool = True
overlap_param_gather: bool = False
align_param_gather: bool = False
@ -531,12 +610,24 @@ class PPOHyperparameters:
eps_clip: float = field(
default=0.2, metadata={"help": "Clipping factor for policy ratio"}
)
c_clip: Optional[float] = field(
default=None,
metadata={
"help": "Dual clipping factor for policy ratio, must > 1.0. None disables dual clipping."
},
)
value_eps_clip: float = field(
default=0.2, metadata={"help": "Clipping factor for value updates"}
)
early_stop_imp_ratio: float = field(
default=5.0, metadata={"help": "Early stop threshold for importance ratio"}
)
actor_sample_reuse: int = field(
default=1, metadata={"help": "The data reuse (aka PPO epoch) for actor."}
)
critic_sample_reuse: int = field(
default=1, metadata={"help": "The data reuse (aka PPO epoch) for critic."}
)
# Reward Processing
max_reward_clip: float = field(
@ -780,6 +871,10 @@ class BaseExperimentConfig:
"help": "Debug mode. False disables assertions for better performance."
},
)
metric_discovery_port: int = field(
default=0,
metadata={"help": "Discovery port for prometheus metrics service discovery."},
)
partition: str = field(
default="dev", metadata={"help": "SLURM partition for running the experiment."}
)
@ -845,6 +940,13 @@ class BaseExperimentConfig:
"Format: 'NODE01:0,1,2,3' or 'NODE[01-02,03,07],COM08'."
},
)
exclude: Optional[str] = field(
default=None,
metadata={
"help": "SLURM nodelist to exclude from allocation. "
"Format: 'NODE01:0,1,2,3' or 'NODE[01-02,03,07],COM08'."
},
)
seed: int = field(default=1, metadata={"help": "Random seed for reproducibility."})
cache_clear_freq: Optional[int] = field(
default=10,
@ -886,6 +988,58 @@ class BaseExperimentConfig:
mem_per_model_worker: int = field(
default=90000, metadata={"help": "Memory per model worker (MB)."}
)
shuffle_dataset: bool = field(
default=True, metadata={"help": "Shuffle in each epoch."}
)
## Configuration options of asynchronous experiments. ##
@dataclass
class AsyncRLOptions:
new_tokens_per_chunk: int = field(
default=1024,
metadata={"help": "The lenght of chunked generation."},
)
max_head_offpolicyness: int = field(
default=0,
metadata={"help": "Maximum off-policyness tolerance for the first token."},
)
n_rollout_workers: Optional[int] = field(
default=None,
metadata={
"help": "Number of rollout workers. None defaults to train world size."
},
)
max_concurrent_rollouts: int = field(
default=1024,
metadata={"help": "Max concurrent rollout jobs in each worker."},
)
flush_request_timeout: int = field(
default=120,
metadata={"help": "The timeout of flushing requests upon weight update."},
)
cpus_per_generation_server: int = field(
default=4, metadata={"help": "Generation server CPUs."}
)
mem_per_generation_server: int = field(
default=60 * 1024, metadata={"help": "Generation server CPU memory in MB."}
)
cpus_per_gserver_manager: int = field(
default=4, metadata={"help": "Generation manager CPUs."}
)
mem_per_gserver_manager: int = field(
default=10 * 1024, metadata={"help": "Generation manager CPU memory in MB."}
)
cpus_per_rollout_worker: int = field(
default=4, metadata={"help": "Rollout worker CPUs."}
)
mem_per_rollout_worker: int = field(
default=20 * 1024, metadata={"help": "Rollout worker CPU memory in MB."}
)
## Configurations for practical experiments. ##
@ -1058,6 +1212,67 @@ class PPOMATHExperimentOptions:
default=0.0, metadata={"help": "Maximum percentage of dataset to each filter."}
)
success_rate_ub: float = field(
default=1.0,
metadata={
"help": "Success rate higher than this value will be filtered out after generation. Valid for async training."
},
)
success_rate_lb: float = field(
default=0.0,
metadata={
"help": "Success rate lower than this value will be filtered out after generation. Valid for async training."
},
)
@dataclass
class MathCodeEvalOptions:
gen_config: GenerationHyperparameters = field(
default_factory=GenerationHyperparameters
)
actor: ModelTrainEvalConfig = field(
default_factory=ModelTrainEvalConfig,
metadata={"help": "Primary LLM configuration."},
)
rew: ModelTrainEvalConfig = field(
default_factory=ModelTrainEvalConfig,
metadata={"help": "Reward model configuration."},
)
actor_gen: MFCConfig = field(
default_factory=MFCConfig, metadata={"help": "Rollout MFC configuration."}
)
rew_inf: MFCConfig = field(
default_factory=MFCConfig, metadata={"help": "InfReward MFC configuration."}
)
dataset: PromptOnlyDatasetConfig = field(
default_factory=PromptOnlyDatasetConfig,
metadata={"help": "Dataset configuration."},
)
group_size: int = field(
default=1,
metadata={"help": "Number of answers retained per prompt (best-of-n)."},
)
rw_type: Optional[str] = field(
default="sparse",
metadata={
"help": "Type of reward processing. Only `sparse` is valid for now.",
"choices": ["sparse"],
},
)
check_xml_format: bool = field(
default=False, metadata={"help": "Validate XML format in generated responses."}
)
check_verifier_status: bool = field(
default=False,
metadata={"help": "Raise error if reward is all-zero (verifier bug check)."},
)
## A helper function to visualize the helper messages. ##
from rich.console import Console

View File

@ -0,0 +1,36 @@
# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0 (the "License").
# Experimental APIs of RL agents.
import asyncio
from abc import ABC
from typing import List
from realhf.api.core.config import AgentAbstraction
from realhf.api.core.data_api import SequenceSample
from realhf.api.core.env_api import EnvironmentService
class Agent(ABC):
# TODO: implement type checking inside each queue.
async def collect_trajectory(
self,
prompt: SequenceSample,
env: EnvironmentService,
obs_queue: asyncio.Queue,
act_queue: asyncio.Queue,
) -> List[SequenceSample]:
raise NotImplementedError()
ALL_AGENTS = {}
def register_agent(name, cls_):
    assert name not in ALL_AGENTS
    ALL_AGENTS[name] = cls_
def make_agent(cfg: AgentAbstraction) -> Agent:
    return ALL_AGENTS[cfg.type_](**cfg.args)
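# A minimal usage sketch of the registry above. `EchoAgent` is illustrative
# only and is not an agent shipped in this repository:
class EchoAgent(Agent):
    async def collect_trajectory(
        self,
        prompt: SequenceSample,
        env: EnvironmentService,
        obs_queue: asyncio.Queue,
        act_queue: asyncio.Queue,
    ) -> List[SequenceSample]:
        await obs_queue.put(prompt)  # expose the prompt as the only observation
        _ = await act_queue.get()  # consume one action, then finish
        return [prompt]
register_agent("echo", EchoAgent)
# agent = make_agent(AgentAbstraction("echo"))  # -> EchoAgent instance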

View File

@ -16,6 +16,18 @@ class DatasetAbstraction:
args: Dict[str, Any] = dataclasses.field(default_factory=dict)
@dataclasses.dataclass
class EnvServiceAbstraction:
type_: str = "null"
args: Dict[str, Any] = dataclasses.field(default_factory=dict)
@dataclasses.dataclass
class AgentAbstraction:
type_: str = "null"
args: Dict[str, Any] = dataclasses.field(default_factory=dict)
@dataclasses.dataclass
class ModelWrapperAbstraction:
type_: str

View File

@ -39,12 +39,12 @@ from pydantic import field_validator, model_validator
from realhf.api.cli_args import MicroBatchSpec
from realhf.api.core import config as config_api
from realhf.base import constants, datapack, logging
from realhf.base import constants, datapack, logging, seeding
from realhf.base.cluster import spec as cluster_spec
logger = logging.getLogger("api.data")
RL_TASKS = ["math", "code", "rlhf"]
RL_TASKS = ["math", "code", "rlhf", "stem"]
def load_hf_tokenizer(
@ -181,11 +181,22 @@ class SequenceSample:
@field_validator("ids")
@classmethod
def _validate_ids(cls, ids: List[Hashable]) -> List[Hashable]:
def _validate_ids(cls, ids: List[Hashable]) -> List[str]:
ids = list(map(str, ids))
if len(ids) != len(set(ids)):
raise ValueError(f"IDs contain duplicates: {ids}.")
return ids
@field_validator("trailing_shapes")
@classmethod
def _validate_trailing_shapes(
cls, trailing_shapes: Dict
) -> Dict[str, Tuple | None]:
for k, v in trailing_shapes.items():
if v is not None:
trailing_shapes[k] = tuple(v)
return trailing_shapes
@field_validator("keys")
@classmethod
def _validate_keys_type(cls, keys: Iterable) -> Set[str]:
@ -496,6 +507,18 @@ class SequenceSample:
self.seqlens.update(other.seqlens)
self.metadata.update(other.metadata)
@staticmethod
def shuffled(sample: "SequenceSample") -> "SequenceSample":
"""Create a shuffled sample.
Define it as a staticmethod because it is an out-of-place operation.
(Think about the difference between `sorted` and `l.sort()`).
"""
seed = seeding.get_shuffle_seed()
rng = np.random.RandomState(seed)
indices = np.arange(sample.bs)
rng.shuffle(indices)
return SequenceSample.reorder(sample, indices)
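    # A sketch of the out-of-place contract above (assumes `sample` is an
    # existing SequenceSample):
    #   shuffled = SequenceSample.shuffled(sample)   # new object, permuted order
    #   assert set(shuffled.ids) == set(sample.ids)  # same entries, new order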
@staticmethod
def _resolve_seqlen_from_key(key, seqlens: List[int]) -> List[torch.Tensor]:
if key in [
@ -656,6 +679,44 @@ class SequenceSample:
finally:
cls.__init__ = original_init
def as_json_compatible(self) -> Dict:
return dict(
ids=self.ids,
keys=list(self.keys),
trailing_shapes={
k: tuple(v) if v is not None else None
for k, v in self.trailing_shapes.items()
},
dtypes={k: str(v) if v is not None else v for k, v in self.dtypes.items()},
seqlens=self.seqlens,
data={
k: v.cpu().numpy().tolist() if v is not None else None
for k, v in self.data.items()
},
metadata=self.metadata,
)
@classmethod
def from_json_compatible(cls, data: Dict):
dtypes = {}
for k, dtype_str in data["dtypes"].items():
if dtype_str is not None:
dtypes[k] = getattr(torch, dtype_str.split(".")[1])
else:
dtypes[k] = None
return cls(
ids=data["ids"],
keys=set(data["keys"]),
trailing_shapes=data["trailing_shapes"],
dtypes=dtypes,
seqlens=data["seqlens"],
data={
k: torch.tensor(v, dtype=dtypes[k]) if v is not None else v
for k, v in data["data"].items()
},
metadata=data["metadata"],
)
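    # A round-trip sketch for the two helpers above: the payload is meant to
    # cross process boundaries as JSON, assuming seqlens and metadata are
    # already JSON-friendly.
    #   payload = sample.as_json_compatible()
    #   wire = json.dumps(payload)
    #   restored = SequenceSample.from_json_compatible(json.loads(wire))
    #   assert restored.ids == sample.ids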
@dataclasses.dataclass
class DataBatchMeta:
@ -824,3 +885,23 @@ def gather_stat(src: List[Dict]) -> Dict:
f"before returning: ({[x.get(k, None) for x in src]}, {v})."
)
return res
def tabulate_stats(data: Dict[str, float], col=4, floatfmt=".4e") -> str:
from tabulate import tabulate
items = list(data.items())
# Calculate how many rows we'll need
row_count = (len(items) + col - 1) // col
# Reorganize items in column-major order
column_major = []
for i in range(row_count):
row = []
for j in range(col):
index = i + j * row_count
if index < len(items):
row.extend(items[index])
column_major.append(row)
return tabulate(column_major, floatfmt=floatfmt, tablefmt="fancy_grid")
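# A tiny worked example of the column-major layout above (values illustrative):
#   tabulate_stats({"a": 1.0, "b": 2.0, "c": 3.0, "d": 4.0, "e": 5.0}, col=4)
# gives row_count = (5 + 3) // 4 = 2, so the grid reads
#   a 1.0 | c 3.0 | e 5.0
#   b 2.0 | d 4.0 |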

View File

@ -0,0 +1,47 @@
import abc
import asyncio
from typing import Any, Dict, List, Tuple
from realhf.api.core.config import EnvServiceAbstraction
class EnvironmentService(abc.ABC):
# TODO: import gymnasium, use its types and signatures
async def step(self, action: Any) -> Tuple[Any, Any, bool, bool, Dict]:
# obs, reward, terminated, truncated, info
raise NotImplementedError()
async def reset(self, seed=None, options=None) -> Tuple[Any, Dict]:
# obs, info
raise NotImplementedError()
ALL_ENV_CLASSES = {}
def register_environment(name, env_cls):
assert name not in ALL_ENV_CLASSES
assert "/" not in name
ALL_ENV_CLASSES[name] = env_cls
class NullEnvironment:
async def step(self, action):
await asyncio.sleep(1)
# obs, reward, terminated, truncated, info
return None, 0.0, True, False, {}
async def reset(self, seed=None, options=None) -> Tuple[Any, Dict]:
await asyncio.sleep(0.1)
return None, {}
register_environment("null", NullEnvironment)
def make_env(
cfg: EnvServiceAbstraction,
) -> EnvironmentService:
return ALL_ENV_CLASSES[cfg.type_](**cfg.args)
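# A minimal driver sketch for the API above, using the built-in "null"
# environment (asyncio.run is shown only for illustration):
async def _null_rollout():
    env = make_env(EnvServiceAbstraction(type_="null"))
    obs, info = await env.reset()
    obs, reward, terminated, truncated, info = await env.step(action=None)
    return reward, terminated
# asyncio.run(_null_rollout())  # -> (0.0, True)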

View File

@ -33,47 +33,118 @@ class ZeroTotalLossWeightException(Exception):
pass
@dataclasses.dataclass
class GenRespMeta:
qid: str
accepted: bool
@dataclasses.dataclass
class GenReqMeta:
## Meta info used to schedule the request. ##
prompt_len: int
group_size: int
new_token_budget: int
predicted_new_tokens: int | None
@dataclasses.dataclass
class ModelVersionReq:
server_url: str
@dataclasses.dataclass
class APIGenerateInput:
# The unique query id of this prompt
qid: Hashable
group_idx: int
# prompt token ids
prompt_ids: List[int]
# prompt token ids + generated prefix, the input to server
input_ids: List[int]
    # the sampling params sent to the server; may limit n=1 and max_new_tokens
    # for partial rollout
gconfig: GenerationHyperparameters
# stop tokens, usually EOS and PAD
stop_token_ids: List[int] = dataclasses.field(default_factory=list)
return_logprob: bool = False
# whether to return logprobs
return_logprob: bool = True
    # logprobs of previous generation
# length len(input_ids) - len(prompt_ids)
prev_logprobs: List[float] = dataclasses.field(default_factory=list)
# the weight version when submitting this request
version_start: int = -1
# other metadata
metadata: Dict[str, Any] = dataclasses.field(default_factory=dict)
@dataclasses.dataclass
class APIGenerateOutput:
## input re-export ##
qid: Hashable
group_idx: int
prompt_ids: List[int]
input_ids: List[int]
output_ids: List[int] = dataclasses.field(default_factory=list)
output_logprobs: List[int] = dataclasses.field(default_factory=list)
no_eos: bool = True
success: bool = False
latency: float = 0.0
ttft: float = 0.0 # Time to first token
gconfig: GenerationHyperparameters
prev_logprobs: List[float] = dataclasses.field(default_factory=list)
version_start: int = -1
metadata: Dict[str, Any] = dataclasses.field(default_factory=dict)
## outputs. To be amended by the reply. ##
# output token ids
output_ids: List[List[int]] = dataclasses.field(default_factory=list)
# output logprobs with the same length as output_ids
output_logprobs: List[List[float]] = dataclasses.field(default_factory=list)
# the weight version when finishing this request
version_end: List[int] = dataclasses.field(default_factory=list)
# whether truncated
no_eos: List[bool] = dataclasses.field(default_factory=list)
# statistics
latency: float = float("inf")
ttft: float = float("inf") # Time to first token
itl: List[float] = dataclasses.field(
default_factory=list
) # List of inter-token latencies
error: str = ""
@classmethod
def from_input(cls, inp: APIGenerateInput):
return cls(
qid=inp.qid,
group_idx=inp.group_idx,
prompt_ids=inp.prompt_ids,
input_ids=inp.input_ids,
gconfig=inp.gconfig,
prev_logprobs=inp.prev_logprobs,
version_start=inp.version_start,
metadata=inp.metadata,
)
@staticmethod
def concat(outputs: List["APIGenerateOutput"]):
return APIGenerateOutput(
qid=outputs[0].qid,
prompt_ids=outputs[0].prompt_ids,
input_ids=outputs[0].input_ids,
gconfig=outputs[0].gconfig,
prev_logprobs=outputs[0].prev_logprobs,
version_start=outputs[0].version_start,
metadata=outputs[0].metadata,
output_ids=sum([o.output_ids for o in outputs], []),
output_logprobs=sum([o.output_logprobs for o in outputs], []),
version_end=sum([o.version_end for o in outputs], []),
no_eos=sum([o.no_eos for o in outputs], []),
latency=max([o.latency for o in outputs]),
ttft=max([o.ttft for o in outputs]),
itl=sum([o.itl for o in outputs], []),
)
@property
def output_len(self):
def group_size(self):
return len(self.output_ids)
@property
def output_lens(self):
return [len(x) for x in self.output_ids]
@property
def input_len(self):
return len(self.input_ids)
@ -83,27 +154,75 @@ class APIGenerateOutput:
return len(self.prompt_ids)
@property
def gen_len(self):
return self.output_len + self.input_len - self.prompt_len
def gen_lens(self):
return [len(x) + self.input_len - self.prompt_len for x in self.output_ids]
def get_logprobs(self) -> List[List[float]]:
logprobs = []
for logp in self.output_logprobs:
assert len(self.prev_logprobs) == self.input_len - self.prompt_len, (
len(self.prev_logprobs),
self.input_len,
self.prompt_len,
)
logprobs.append([0.0] * (self.prompt_len - 1) + self.prev_logprobs + logp)
return logprobs
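    # Length bookkeeping for get_logprobs above, as a worked example with
    # illustrative numbers: prompt_len = 4, input_len = 6 (2 previously
    # generated tokens), and 3 newly generated tokens give
    #   [0.0] * (4 - 1) + prev_logprobs (len 2) + logp (len 3) -> 8 entries,
    # i.e. one logprob per token of the 9-token final sequence except the first.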
@dataclasses.dataclass
class BundledGenerationOutputs:
## Used for collecting generation outputs for env interaction or training. ##
# unique query id in the dataset
qid: Hashable
# prompt token ids
prompt_ids: List[int]
# output token ids excluding the prompt
output_ids: List[List[int]]
# whole sequences including the prompt
seqs: List[List[int]]
# whole logprobs, one token shorter than seq
# logps at prompt tokens are zero
logprobs: List[List[float]]
# whether truncated
no_eos: List[bool]
# server weight version when starting generation
version_start: List[int]
# server weight version when generation ends
version_end: List[int]
@classmethod
def from_single(cls, outputs: List[APIGenerateOutput]):
def from_api_outputs(cls, outputs: List[APIGenerateOutput]):
assert len(set(o.qid for o in outputs)) == 1
prompt_len = len(outputs[0].prompt_ids)
seqs = []
logprobs = []
version_starts = []
for o in outputs:
for out in o.output_ids:
seqs.append(o.input_ids + out)
for logp in o.get_logprobs():
logprobs.append(logp)
version_starts += [o.version_start] * o.group_size
return cls(
qid=outputs[0].qid,
prompt_ids=outputs[0].prompt_ids,
seqs=[o.input_ids + o.output_ids for o in outputs],
no_eos=[o.no_eos for o in outputs],
seqs=seqs,
output_ids=[seq[prompt_len:] for seq in seqs],
logprobs=logprobs,
no_eos=sum([o.no_eos for o in outputs], []),
version_start=version_starts,
version_end=sum([o.version_end for o in outputs], []),
)
@property
def output_logprobs(self):
return [lp[self.prompt_len - 1 :] for lp in self.logprobs]
@property
def output_lens(self):
return [len(out) for out in self.output_ids]
@property
def seqlens(self):
return [len(seq) for seq in self.seqs]
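    # Shape bookkeeping for the bundle above (illustrative numbers): with
    # prompt_len = 4 and two sampled answers of 3 and 5 new tokens, seqs have
    # lengths [7, 9], output_ids lengths [3, 5], logprobs lengths [6, 8]
    # (one per token except the first), and output_logprobs lengths [3, 5]
    # after dropping the prompt_len - 1 = 3 leading zeros.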
@ -113,7 +232,10 @@ class BundledGenerationOutputs:
return len(self.prompt_ids)
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(
total=6 * 60 * 60,
connect=300,
)
class LLMAPIClient:
@ -128,7 +250,7 @@ class LLMAPIClient:
self.semaphore: asyncio.Semaphore
async def __aenter__(self):
conn = aiohttp.TCPConnector(limit=0, ttl_dns_cache=300)
conn = aiohttp.TCPConnector(limit=0, ttl_dns_cache=300, force_close=True)
self.session = aiohttp.ClientSession(
timeout=AIOHTTP_TIMEOUT,
connector=conn,
@ -391,11 +513,11 @@ class PipelinableEngine(abc.ABC):
self,
input_: SequenceSample,
mb_spec: MicroBatchSpec,
loss_fn: Callable[[torch.Tensor, SequenceSample], Tuple[torch.Tensor, Dict]],
loss_fn: Callable[[torch.Tensor, SequenceSample], torch.Tensor],
loss_weight_fn: Callable[[torch.Tensor, SequenceSample], float],
version_steps: int,
token_normalize_scope: Literal["global", "dp"] = "global",
) -> Tuple[torch.Tensor, Dict] | None:
) -> Dict:
"""Update the model with a batch of data and a loss function.
:param input_: The input data. It should contain at least the key ``packed_input_ids``,
@ -403,8 +525,8 @@ class PipelinableEngine(abc.ABC):
entries required to compute the loss.
:type input_: SequenceSample
:param loss_fn: The loss function. It takes the output of the forward pass and the
input data, returning the loss and a dictionary of statistics.
:type loss_fn: Callable[[torch.Tensor, SequenceSample], Tuple[torch.Tensor, Dict]]
input data, returning the loss.
:type loss_fn: Callable[[torch.Tensor, SequenceSample], torch.Tensor]
:param loss_weight_fn: This function is used to calculate the number of valid tokens
when normalizing loss across micro batches and DP ranks. Can be `lambda: 1`
if just taking the average over batches.
@ -412,12 +534,6 @@ class PipelinableEngine(abc.ABC):
:param version_steps: The global step counter for this experiment,
used by the backend to determine the learning rate schedule.
:type version_steps: int
:param num_micro_batches: The number of micro-batches to split the batch into.
Gradients will be accumulated across micro-batches, and only one update will
occur. For pipelined training, micro-batches are processed together by the engine,
which automatically schedules the forward and backward passes. For non-pipelined
training, forward and backward passes are executed iteratively over mini-batches
to accumulate gradients. If None, the batch will not be split.
:param global_normalize_scope: The scope of token-wise loss normalization. Choices:
global: average across all micro batches across DP ranks.
dp: average across micro batches in current DP rank.
@ -431,8 +547,8 @@ class PipelinableEngine(abc.ABC):
self,
input_: SequenceSample,
mb_spec: MicroBatchSpec,
loss_fn: Callable[[torch.Tensor, SequenceSample], Tuple[torch.Tensor, Dict]],
) -> Tuple[torch.Tensor, Dict] | None:
loss_fn: Callable[[torch.Tensor, SequenceSample], torch.Tensor],
) -> torch.Tensor | None:
"""Evaluate the model using the forward pass and loss function.
This method wraps :meth:`forward` with a customized ``post_hook`` and ``aggregate_fn``.
@ -442,22 +558,21 @@ class PipelinableEngine(abc.ABC):
entries required to compute the loss.
:type input_: SequenceSample
:param loss_fn: The loss function. It takes the output of the forward pass and the
input data, returning the loss and a dictionary of statistics.
:type loss_fn: Callable[[torch.Tensor, SequenceSample], Tuple[torch.Tensor, Dict]]
:return: The aggregated scalar loss and a dictionary of statistics from the last pipeline
stage. Returns None otherwise.
:rtype: Tuple[torch.Tensor, Dict]
input data, returning the loss.
:type loss_fn: Callable[[torch.Tensor, SequenceSample], torch.Tensor]
:return: The aggregated scalar loss if on the last pipe stage.
:rtype: torch.Tensor | None
"""
def agg(xs: List[Tuple[torch.Tensor, Dict]]):
losses, stats = zip(*xs)
return sum(losses), {k: sum(s[k] for s in stats) for k in stats[0].keys()}
def _loss_fn(out, inp_):
# To prevent calling data reordering.
return float(loss_fn(out, inp_))
return self.forward(
input_=input_,
mb_spec=mb_spec,
post_hook=loss_fn,
aggregate_fn=agg,
post_hook=_loss_fn,
aggregate_fn=sum,
)
def forward(
@ -511,11 +626,6 @@ class PipelinableEngine(abc.ABC):
:type tokenizer: transformers.PreTrainedTokenizerFast
:param gconfig: The generation hyperparameters.
:type gconfig: GenerationHyperparameters
:param num_micro_batches: The number of micro-batches to split the batch into.
Regardless of pipelining, mini-batches will be processed one-by-one by the module.
This approach helps reduce GPU memory usage for hidden states and KV-caches.
If None, the batch will not be split.
:type num_micro_batches: Optional[int]
:return: For the last pipeline stage, returns the generated tokens, log probabilities, and optionally the logits mask.
See :class:`GenerationHyperparameters` for more details about the logits mask.
Returns None for other stages.

View File

@ -4,7 +4,7 @@
import dataclasses
import os
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import realhf.api.core.dfg as dfg
from realhf.api.cli_args import (
@ -14,7 +14,9 @@ from realhf.api.cli_args import (
WandBConfig,
)
from realhf.api.core.config import (
AgentAbstraction,
DatasetAbstraction,
EnvServiceAbstraction,
ModelAbstraction,
ModelName,
ModelShardID,
@ -23,9 +25,6 @@ from realhf.api.core.config import (
from realhf.base import constants, topology
from realhf.base.cluster import spec as cluster_spec
_LLM_GPU_IMAGE = cluster_spec.gpu_image
_LLM_CPU_IMAGE = cluster_spec.cpu_image
@dataclasses.dataclass
class Scheduling:
@ -36,7 +35,7 @@ class Scheduling:
node_type: str = None
nodelist: str = None
exclude: str = None
container_image: str = _LLM_CPU_IMAGE
container_image: str = cluster_spec.cpu_image
env_vars: Dict[str, str] = dataclasses.field(default_factory=dict)
# time utils from "https://slurm.schedmd.com/sbatch.html"
time_limit: Optional[str] = None # see "--time" option for format
@ -50,7 +49,7 @@ class Scheduling:
"cpu": 16,
"mem": 20 * 1024,
"gpu": 0,
"container_image": _LLM_CPU_IMAGE,
"container_image": cluster_spec.cpu_image,
**kwargs,
}
)
@ -62,7 +61,43 @@ class Scheduling:
"cpu": 2,
"gpu": 1,
"mem": 60 * 1024,
"container_image": _LLM_GPU_IMAGE,
"container_image": cluster_spec.gpu_image,
**kwargs,
}
)
@staticmethod
def generation_server_default(**kwargs):
return Scheduling(
**{
"cpu": 4,
"gpu": 1,
"mem": 60 * 1024,
"container_image": cluster_spec.gpu_infer_image,
**kwargs,
}
)
@staticmethod
def gserver_manager_default(**kwargs):
return Scheduling(
**{
"cpu": 4,
"gpu": 0,
"mem": 10 * 1024,
"container_image": cluster_spec.gpu_image,
**kwargs,
}
)
@staticmethod
def rollout_worker_default(**kwargs):
return Scheduling(
**{
"cpu": 4,
"gpu": 0,
"mem": 20 * 1024,
"container_image": cluster_spec.gpu_image,
**kwargs,
}
)
@ -119,6 +154,7 @@ class ModelWorker:
datasets: Optional[List[Union[str, DatasetAbstraction]]] = None
use_dataset_cache: bool = False
dataset_cahce_root: str = constants.DATASET_CACHE_PATH
shuffle_dataset: bool = True
cuda_cache_cleanliness: bool = True
cuda_cache_clear_freq: int = 10
torch_cache_mysophobia: bool = False
@ -126,8 +162,8 @@ class ModelWorker:
model_rpcs: List[dfg.MFCDef] = None
model_topos: Dict[ModelName, topology.ProcessTopology] = None
msid2mwid: Dict[ModelShardID, int] = None
data_transfer_pairs: List[Tuple[str, str]] = None
sync_param_pairs: List[Tuple[str, str]] = None
data_transfer_pairs: List[Tuple[ModelName, ModelName]] = None
sync_param_pairs: List[Tuple[ModelName, ModelName]] = None
# profiling
profile_mode: bool = False
worker_info: Optional[WorkerInformation] = None
@ -140,12 +176,50 @@ class ModelWorker:
)
@dataclasses.dataclass
class GenerationServer:
base_seed: int
backend_type: str
backend_args: Any
model_path: str
tp_size: int
worker_info: WorkerInformation = None
@dataclasses.dataclass
class GserverManager:
model_name: ModelName
n_servers: int
schedule_policy: str
max_head_offpolicyness: int
train_batch_size: int
flush_request_timeout: int
max_concurrent_rollouts: int
worker_info: WorkerInformation = None
@dataclasses.dataclass
class RolloutWorker:
base_seed: int
model_name: ModelName
tokenizer_path: str
new_tokens_per_chunk: int
rollout_request_timeout: int
env: EnvServiceAbstraction
agent: AgentAbstraction
datasets: List[Union[str, DatasetAbstraction]]
use_dataset_cache: bool = False
dataset_cahce_root: str = constants.DATASET_CACHE_PATH
worker_info: WorkerInformation = None
@dataclasses.dataclass
class MasterWorker:
base_seed: int
exp_ctrl: ExperimentSaveEvalControl
# main components
n_model_workers: int
shuffle_dataset: bool = True
model_rpcs: List[dfg.MFCDef] = None
model_topos: Dict[ModelName, topology.ProcessTopology] = None
msid2mwid: Dict[ModelShardID | str, int] = None
@ -162,13 +236,12 @@ class TasksGroup:
@dataclasses.dataclass
class ExperimentScheduling:
model_worker: Union[List[TasksGroup], TasksGroup] = dataclasses.field(
default_factory=list
)
master_worker: Union[List[TasksGroup], TasksGroup] = dataclasses.field(
default_factory=list
)
controller_image: str = _LLM_CPU_IMAGE
model_worker: TasksGroup
master_worker: TasksGroup
generation_server: TasksGroup | None = None
gserver_manager: TasksGroup | None = None
rollout_worker: TasksGroup | None = None
controller_image: str = cluster_spec.cpu_image
@dataclasses.dataclass
@ -179,6 +252,9 @@ class ExperimentConfig:
# dataflow
model_rpcs: List[dfg.MFCDef]
model_worker: List[ModelWorker] = dataclasses.field(default_factory=list)
generation_server: List[GenerationServer] = dataclasses.field(default_factory=list)
gserver_manager: List[GserverManager] = dataclasses.field(default_factory=list)
rollout_worker: List[RolloutWorker] = dataclasses.field(default_factory=list)
# master_worker will be set automatically
master_worker: Optional[List[MasterWorker]] = None
# automatic evaluation
@ -193,6 +269,7 @@ class ExperimentConfig:
base_seed=self.model_worker[0].base_seed,
exp_ctrl=self.exp_ctrl,
n_model_workers=len(self.model_worker),
shuffle_dataset=self.model_worker[0].shuffle_dataset,
)
]
@ -259,12 +336,12 @@ class ExperimentConfig:
return getattr(self, worker_type)[worker_index]
def set_worker_information(self, experiment_name, trial_name):
if len(self.model_worker) > 0:
assert len(self.master_worker) == 1
for worker_type, workers in [
("model_worker", self.model_worker),
("master_worker", self.master_worker),
("gserver_manager", self.gserver_manager),
("rollout_worker", self.rollout_worker),
("generation_server", self.generation_server),
]:
if len(workers) == 0:
continue

View File

@ -3,7 +3,6 @@
# Licensed under the Apache License, Version 2.0 (the "License").
import dataclasses
import json
import math
from typing import List, Optional, Tuple, Union
@ -12,11 +11,7 @@ import numpy as np
from realhf.api.cli_args import ParallelismConfig
from realhf.api.core.dfg import MFCDef
from realhf.base.cluster import spec as cluster_spec
from realhf.base.slurm_utils import (
are_ones_contiguous,
nodelist_from_nodes,
parse_nodelist,
)
from realhf.base.slurm_utils import are_ones_contiguous, parse_nodelist
@dataclasses.dataclass
@ -27,22 +22,12 @@ class DeviceMesh:
# a 2D binary array of current device mesh name
# shape: (n_nodes, n_gpus_per_node)
mapping: np.ndarray
# For slurm cluster: nodelist string of all
# allocated nodes in the cluster
global_mesh_name: str = None
# For slurm cluster: nodelist string this device mesh
name: str = None
# cluster info, GPU memory cap in bytes
gpu_memory_capacity: int = 80 * (1024**3)
def to_dict(self):
return dict(
n_nodes=self.n_nodes,
n_gpus_per_node=self.n_gpus_per_node,
mapping=self.mapping.tolist(),
global_mesh_name=self.global_mesh_name,
name=self.name,
gpu_memory_capacity=self.gpu_memory_capacity,
)
def split(self, split_n_gpus: int) -> Tuple["DeviceMesh", "DeviceMesh"]:
@ -73,15 +58,11 @@ class DeviceMesh:
n_nodes=self.n_nodes,
n_gpus_per_node=self.n_gpus_per_node,
mapping=sub_mapping1,
global_mesh_name=self.global_mesh_name,
name=device_mesh_name_from_mapping(self.global_mesh_name, sub_mapping1),
),
DeviceMesh(
n_nodes=self.n_nodes,
n_gpus_per_node=self.n_gpus_per_node,
mapping=sub_mapping2,
global_mesh_name=self.global_mesh_name,
name=device_mesh_name_from_mapping(self.global_mesh_name, sub_mapping2),
),
)
assert d1._is_valid_mapping()
@ -96,34 +77,12 @@ class DeviceMesh:
def __post_init__(self):
n = cluster_spec.suffix_n_digits
if self.global_mesh_name is None:
self.global_mesh_name = (
f"{cluster_spec.node_name_prefix}[{1:0{n}d}-{self.n_nodes:0{n}d}]"
if self.n_nodes > 1
else f"{cluster_spec.node_name_prefix}{1:0{n}d}"
)
if self.global_mesh_name is not None and self.name is None:
self.name = device_mesh_name_from_mapping(
self.global_mesh_name, self.mapping
)
assert self._is_valid_mapping()
def __eq__(self, other: "DeviceMesh"):
assert (
self.global_mesh_name is None
or self.global_mesh_name == other.global_mesh_name
), "Only device meshes that on the same cluster mesh is comparable"
return np.all(self.mapping == other.mapping)
def __repr__(self):
return f"DeviceMesh({self.name} in {self.global_mesh_name})"
def __op_assertion(self, other: "DeviceMesh"):
assert (
self.global_mesh_name is None
or self.global_mesh_name == other.global_mesh_name
), "operation only support device meshes on the same cluster nodes"
assert self.n_nodes == other.n_nodes
assert self.n_gpus_per_node == other.n_gpus_per_node
@ -184,8 +143,6 @@ class DeviceMesh:
n_nodes=self.n_nodes,
n_gpus_per_node=self.n_gpus_per_node,
mapping=sub_mapping,
global_mesh_name=self.global_mesh_name,
name=device_mesh_name_from_mapping(self.global_mesh_name, sub_mapping),
)
for sub_mapping in sub_mappings
]
@ -193,24 +150,15 @@ class DeviceMesh:
def _is_valid_mapping(self) -> bool:
if self.mapping.shape != (self.n_nodes, self.n_gpus_per_node):
raise RuntimeError(
f"Invalid mapping shape {self.mapping.shape} " f"{self.name}"
f"Invalid mapping shape {self.mapping.shape} n_nodes={self.n_nodes} "
f"n_gpus_per_node={self.n_gpus_per_node}"
)
if not np.all(np.logical_or(self.mapping == 0, self.mapping == 1)):
raise RuntimeError(f"Invalid mapping value {self.mapping}")
assert math.log(self.n_gpus_per_node, 2).is_integer()
one_node_valid_gpus = [
2**i for i in range(int(math.log(self.n_gpus_per_node, 2)))
]
if self.mapping.sum() < self.n_gpus_per_node:
if not any(self.mapping.sum() == g for g in one_node_valid_gpus):
raise RuntimeError(
f"Invalid mapping {self.mapping}. "
"If using GPUs less than an entire node, "
"only 1, 2, 4, 8, ... GPUs are allowed."
)
else:
if self.mapping.sum() >= self.n_gpus_per_node:
if not (
self.mapping.sum() % self.n_gpus_per_node == 0
and np.all(
@ -268,30 +216,9 @@ def make_device_mesh_from_name(
n_nodes=n_nodes,
n_gpus_per_node=n_gpus_per_node,
mapping=mapping,
global_mesh_name=global_mesh_name,
name=name,
)
def device_mesh_name_from_mapping(global_mesh_name: str, mapping: np.ndarray):
prefix = cluster_spec.node_name_prefix
node_list = parse_nodelist(global_mesh_name, prefix)
n_nodes = len(node_list)
n_gpus_per_node = mapping.shape[1]
assert mapping.shape[0] == n_nodes
node_indices, gpu_ids = np.where(mapping == 1)
if np.sum(mapping) < n_gpus_per_node:
node_name = node_list[node_indices[0]]
gpu_ids = list(map(str, gpu_ids))
return f"{node_name}:{','.join(gpu_ids)}"
else:
unique_node_indices = np.unique(node_indices)
sub_node_list = [node_list[i] for i in unique_node_indices]
node_name = nodelist_from_nodes(sub_node_list, prefix)
return node_name
def find_parallel_strategies(
device_mesh: DeviceMesh,
) -> List[ParallelismConfig]:

View File

@ -7,6 +7,7 @@ import getpass
import os
import re
import time
import uuid
from typing import Dict, List, Optional
import realhf.api.core.system_api as config_package
@ -19,6 +20,7 @@ import realhf.scheduler.client as sched_client
import realhf.system as system
from realhf.scheduler.client import JobException, JobState
from realhf.scheduler.evaluator import AutomaticEvaluator
from realhf.version import get_full_version_with_dirty_description
logger = logging.getLogger("main", "system")
@ -44,6 +46,8 @@ def _submit_workers(
scheduled_jobs = []
for sch_cfg in scheduling_configs:
if sch_cfg is None:
continue
job_environs = {**environs, **sch_cfg.scheduling.env_vars}
cmd = sched_client.remote_worker_cmd(expr_name, trial_name, debug, worker_type)
@ -78,7 +82,12 @@ def _submit_workers(
return scheduled_jobs
def main_start(args, recover_count: int = 0):
def main_start(args, job_group_id: str = "", recover_count: int = 0):
if not job_group_id:
job_group_id = str(uuid.uuid4())
logger.info(f"AReaL Version: {get_full_version_with_dirty_description()}")
logger.info(f"AReaL Job Group ID: {job_group_id}")
logger.info(f"AReaL Job Group Index: {recover_count}")
if recover_count == 0:
constants.set_experiment_trial_names(args.experiment_name, args.trial_name)
experiment = config_package.make_experiment(args.experiment_name)
@ -177,6 +186,8 @@ def main_start(args, recover_count: int = 0):
trial_name=trial_name,
schedule_strategy=args.schedule_strategy,
evaluator=evaluator,
job_group_id=job_group_id,
job_group_index=recover_count,
)
setup = experiment.scheduling_setup()
@ -198,7 +209,10 @@ def main_start(args, recover_count: int = 0):
raise e
logger.info(f"Resetting name resolving repo... Done.")
logger.info(f"Running configuration: {experiment.__class__.__name__}")
logger.info(
f"Running configuration: {experiment.__class__.__name__}. "
f"The current recover retry: {recover_count + 1}/{args.recover_retries}"
)
# Schedule controller
if args.mode == "ray":
@ -281,7 +295,7 @@ def main_start(args, recover_count: int = 0):
f"total recover count {args.recover_retries}"
)
time.sleep(args.recover_after)
main_start(args, recover_count=recover_count + 1)
main_start(args, job_group_id=job_group_id, recover_count=recover_count + 1)
else:
raise e

View File

@ -25,12 +25,17 @@ from realhf.base.prologue import (
get_experiment_name,
get_trial_name,
)
from realhf.version import get_full_version_with_dirty_description
# NOTE: Register all implemented experiments inside ReaL.
import_module(
str(pathlib.Path(__file__).resolve().parent.parent / "experiments" / "common"),
re.compile(r".*_exp\.py$"),
)
import_module(
str(pathlib.Path(__file__).resolve().parent.parent / "experiments" / "async_exp"),
re.compile(r".*_exp\.py$"),
)
import realhf.experiments.benchmark.profile_exp
@ -63,6 +68,10 @@ def print_help(exp_type):
console.print("\n[dim]Use [bold]--help[/bold] to show this message again[/dim]")
def print_version():
console.print(f"AReaL Version: {get_full_version_with_dirty_description()}")
def main():
# Create parser with add_help=False to disable automatic --help
parser = argparse.ArgumentParser(prog="ReaL Quickstart", add_help=False)
@ -71,6 +80,7 @@ def main():
parser.add_argument(
"--help", action="store_true", help="Show this help message and exit"
)
parser.add_argument("--version", action="store_true", help="Show AReaL version")
subparsers = parser.add_subparsers(dest="cmd", help="sub-command help")
subparsers.required = True
@ -94,6 +104,10 @@ def main():
# Parse known args first to check for help
args = vars(parser.parse_known_args()[0])
if args["version"]:
print_version()
return
# Handle help at both main and subcommand levels
if args["help"]:
if args["cmd"]:

View File

@ -25,6 +25,7 @@ from realhf.api.quickstart.entrypoint import (
QUICKSTART_EXPR_CACHE_PATH,
)
from realhf.base import cluster, gpu_utils, importing, logging, name_resolve, names
from realhf.version import get_full_version_with_dirty_description
logger = logging.getLogger("Main-Workers")
@ -63,7 +64,7 @@ def main_worker(args):
worker_index_start + args.wprocs_per_jobstep,
args.wprocs_in_job + args.wproc_offset,
)
if args.worker_type == "master_worker":
if args.worker_type in ["master_worker", "rollout_worker", "gserver_manager"]:
try:
# CUDA_VISIBLE_DEVICES is set by slurm on PPU nodes
# we need to remove it on CPU workers
@ -250,6 +251,7 @@ def main():
args = parser.parse_args()
logger.info(f"AReaL Version: {get_full_version_with_dirty_description()}")
args.func(args)

View File

@ -45,6 +45,7 @@ class ClusterSpec:
self.__gpu_type = spec.get("gpu_type", None)
self.__default_mount = spec.get("default_mount", None)
self.__gpu_image = spec.get("gpu_image", None)
self.__gpu_infer_image = spec.get("gpu_infer_image", self.__gpu_image)
self.__cpu_image = spec.get("cpu_image", None)
self.__node_name_prefix = spec.get("node_name_prefix", "NODE")
# self.__n_nodes decides number of digits in slurm hostnames
@ -116,10 +117,16 @@ class ClusterSpec:
@property
def gpu_image(self) -> str:
"""Return the default image for containers of GPU workers."""
"""Return the default image for containers of GPU trainer workers."""
assert self.__loaded
return self.__gpu_image
@property
def gpu_infer_image(self) -> str:
"""Return the default image for containers of GPU inference workers."""
assert self.__loaded
return self.__gpu_infer_image
@property
def cpu_image(self) -> str:
"""Return the default image for containers of CPU workers."""

View File

@ -190,16 +190,12 @@ _global_memory_buffer: GlobalMemoryBuffer = GlobalMemoryBuffer()
_fake_mp_world_size = None
_fake_mp_rank = None
# GLOBAL_STATS_TRACKER is used to track and log training stats that cannot be gracefully obtained via model outputs
# in interface implementations, e.g. load balancing loss in each MoE layer.
GLOBAL_STATS_TRACKER = defaultdict(dict)
GLOBAL_STATS_TRACKER_LOG_HOOKS = defaultdict(dict)
# TODO: As in Megatron, we can set NCCL group options. Is it necessary?
def reset_run():
global _model_name, _grids, _pgroups, _pgroup_ranks, _self_group, _rank_mapping, _global_memory_buffer, _fake_mp_world_size, _fake_mp_rank, GLOBAL_STATS_TRACKER, GLOBAL_STATS_TRACKER_LOG_HOOKS
global _model_name, _grids, _pgroups, _pgroup_ranks, _self_group, _rank_mapping, _global_memory_buffer, _fake_mp_world_size, _fake_mp_rank
_model_name = None
_grids = {}
_pgroups = {}
@ -209,8 +205,6 @@ def reset_run():
_global_memory_buffer = GlobalMemoryBuffer()
_fake_mp_world_size = None
_fake_mp_rank = None
GLOBAL_STATS_TRACKER = defaultdict(dict)
GLOBAL_STATS_TRACKER_LOG_HOOKS = defaultdict(dict)
@contextlib.contextmanager
@ -453,6 +447,10 @@ def pipe_parallel_group():
return grid().get_pipe_parallel_group()
def pipe_parallel_cpu_group():
return grid().pp_proc_group_gloo
def is_last_pipe_stage():
return pipe_parallel_rank() == pipe_parallel_world_size() - 1
@ -579,77 +577,3 @@ def get_env_vars(**kwargs):
"REAL_PACKAGE_PATH": str(get_repo_path()),
**BASE_ENVIRONS,
}
################# logging related #################
def save_to_global_stats_tracker(
key: str, value: Any, hook: Optional[Callable] = None, **hook_kwargs
):
"""Save kv-pair to global stats tracker for current model.
:param key: Key
:type key: str
:param value: Value
:type value: Any
:param hook: Hook function to be called before logging the stats in `log_global_stats_tracker`.
For example, this hook can be used to gather and average stats across parallel ranks.
:type hook: Optional[Callable]
:param hook_kwargs: Keyword arguments to be passed to the hook function.
"""
if _model_name is None:
raise RuntimeError("Global constant `model_name` is accessed before set.")
GLOBAL_STATS_TRACKER[_model_name][key] = value
if hook is not None:
GLOBAL_STATS_TRACKER_LOG_HOOKS[_model_name][key] = (hook, hook_kwargs)
def get_from_global_stats_tracker(key: str):
if _model_name is None:
raise RuntimeError("Global constant `model_name` is accessed before set.")
return GLOBAL_STATS_TRACKER[_model_name].get(key, None)
def clear_global_stats_tracker():
if _model_name is None:
raise RuntimeError("Global constant `model_name` is accessed before set.")
global GLOBAL_STATS_TRACKER
GLOBAL_STATS_TRACKER[_model_name] = dict()
def log_global_stats_tracker(
return_dict: bool = True, clear_stats_after_logging: bool = True
):
"""Log the global stats tracker and optionally return the stats as a
dictionary. This method is expected to be called in interface
implementations.
:param return_dict: Whether to return the stats as a dictionary.
:type return_dict: bool
:param clear_stats_after_logging: Whether to clear the stats after
logging.
:type clear_stats_after_logging: bool
"""
if _model_name is None:
raise RuntimeError("Global constant `model_name` is accessed before set.")
stats = GLOBAL_STATS_TRACKER[_model_name]
hooks = GLOBAL_STATS_TRACKER_LOG_HOOKS[_model_name]
for key in stats.keys():
hook, hook_kwargs = hooks.get(key, None)
if hook is not None:
hook(**hook_kwargs)
res = {}
if not return_dict:
logger.info(f"Logging global stats tracker:")
for key, value in stats.items():
res[key] = value
if not return_dict:
logger.info(f"{key}: {value}")
if clear_stats_after_logging:
clear_global_stats_tracker()
if return_dict:
return res

View File

@ -137,18 +137,28 @@ def isolate_cuda_device(
# logger.info(f"Rank {rank} discovers local peers with global ranks {local_peers}")
local_peer_index = local_peers.index(str(rank))
if len(os.environ["CUDA_VISIBLE_DEVICES"].split(",")) == len(local_peers):
local_gpu_id = list(map(int, os.environ["CUDA_VISIBLE_DEVICES"].split(",")))[
local_peer_index
]
elif len(os.environ["CUDA_VISIBLE_DEVICES"].split(",")) == 1:
local_gpu_id = int(os.environ["CUDA_VISIBLE_DEVICES"])
n_local_peers = len(local_peers)
visible_devices = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
n_visible_devices = len(visible_devices)
if n_visible_devices == 0:
raise RuntimeError(
f"No visible cuda devices: {os.environ['CUDA_VISIBLE_DEVICES']}"
)
if n_visible_devices == n_local_peers:
local_gpu_id = visible_devices[local_peer_index]
elif n_visible_devices == 1:
local_gpu_id = os.environ["CUDA_VISIBLE_DEVICES"]
elif n_visible_devices % n_local_peers == 0:
# A process occupies multiple GPUs, e.g., TP generation server
factor = n_visible_devices // n_local_peers
local_gpu_id = visible_devices[factor * local_peer_index]
else:
if not os.environ.get("REAL_MODE") == "LOCAL":
raise RuntimeError(
f"Unresolvable CUDA_VISIBLE_DEVICES {os.environ['CUDA_VISIBLE_DEVICES']} on host {network.gethostname()}, "
f"local peers (global ranks) {local_peers}, local peer index {local_peer_index}."
)
# In the local mode, all processes use GPUs in a round-robin manner
devices = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
local_gpu_id = int(devices[local_peer_index % len(devices)])
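        # A worked example of the mapping above (hypothetical values): with
        # CUDA_VISIBLE_DEVICES="0,1,2,3" and 2 local peers, factor = 4 // 2 = 2,
        # so peer 0 receives device "0" and peer 1 receives device "2"; each
        # process may then use `factor` consecutive GPUs (e.g. a TP-2
        # generation server). In local mode, peers instead wrap around the
        # visible devices round-robin.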

View File

@ -141,6 +141,15 @@ def getLogger(
return logging.getLogger(name)
def log_wandb_tensorboard(data, step=None, summary_writer=None):
import wandb
wandb.log(data, step=step)
if summary_writer is not None:
for key, val in data.items():
summary_writer.add_scalar(f"{key}", val, step)
if __name__ == "__main__":
# The following serves as a color visualization test.
# The available color names are black, red, green, yellow, blue, purple, cyan and white

View File

@ -87,7 +87,7 @@ class NameRecordRepository:
        The value is retrievable by get_subtree() given that no other
entries use the name prefix.
"""
sub_name = name.rstrip("/") + "/" + str(uuid.uuid4())[:8]
sub_name = os.path.join(os.path.normpath(name), str(uuid.uuid4())[:8])
self.add(sub_name, value, **kwargs)
return sub_name
@ -194,6 +194,7 @@ class MemoryNameRecordRepository(NameRecordRepository):
def __init__(self, log_events=False):
self.__store = {}
self.__to_delete = set()
self.__log_events = log_events
def add(
@ -204,12 +205,16 @@ class MemoryNameRecordRepository(NameRecordRepository):
keepalive_ttl=None,
replace=False,
):
if not name:
raise ValueError(f"Invalid name: {name}")
name = os.path.normpath(name)
if self.__log_events:
print(f"NameResolve: add {name} {value}")
if name in self.__store and not replace:
raise NameEntryExistsError(f"K={name} V={self.__store[name]} V2={value}")
assert isinstance(value, str)
self.__store[name] = value
self.__store[name] = str(value)
if delete_on_exit:
self.__to_delete.add(name)
def touch(self, name, value, new_time_to_live):
raise NotImplementedError()
@ -219,21 +224,26 @@ class MemoryNameRecordRepository(NameRecordRepository):
print(f"NameResolve: delete {name}")
if name not in self.__store:
raise NameEntryNotFoundError(f"K={name}")
if name in self.__to_delete:
self.__to_delete.remove(name)
del self.__store[name]
def clear_subtree(self, name_root):
if self.__log_events:
print(f"NameResolve: clear_subtree {name_root}")
name_root = name_root.rstrip("/")
name_root = os.path.normpath(name_root)
for name in list(self.__store):
if (
name_root == "/"
or name == name_root
or name.startswith(name_root + "/")
):
if name in self.__to_delete:
self.__to_delete.remove(name)
del self.__store[name]
def get(self, name):
name = os.path.normpath(name)
if name not in self.__store:
raise NameEntryNotFoundError(f"K={name}")
r = self.__store[name]
@ -244,7 +254,7 @@ class MemoryNameRecordRepository(NameRecordRepository):
def get_subtree(self, name_root):
if self.__log_events:
print(f"NameResolve: get_subtree {name_root}")
name_root = name_root.rstrip("/")
name_root = os.path.normpath(name_root)
rs = []
for name, value in self.__store.items():
if (
@ -259,14 +269,20 @@ class MemoryNameRecordRepository(NameRecordRepository):
if self.__log_events:
print(f"NameResolve: find_subtree {name_root}")
rs = []
for name, value in self.__store.items():
if name.startswith(name_root):
for name in self.__store:
if (
name_root == "/"
or name == name_root
or name.startswith(name_root + "/")
):
rs.append(name)
rs.sort()
return rs
def reset(self):
self.__store = {}
for name in self.__to_delete:
self.__store.pop(name)
self.__to_delete = set()
class NfsNameRecordRepository(NameRecordRepository):
@ -294,8 +310,16 @@ class NfsNameRecordRepository(NameRecordRepository):
):
if not name:
raise ValueError("Name cannot be empty")
name = os.path.normpath(name)
path = self.__file_path(name)
os.makedirs(os.path.dirname(path), exist_ok=True)
while True:
# To avoid concurrency issues when multiple processes
            # call makedirs on the same dirname on CPFS.
try:
os.makedirs(os.path.dirname(path), exist_ok=True)
break
except (NotADirectoryError, FileNotFoundError):
pass
if os.path.isfile(path) and not replace:
raise NameEntryExistsError(path)
local_id = str(uuid.uuid4())[:8]
@ -329,6 +353,7 @@ class NfsNameRecordRepository(NameRecordRepository):
logger.info("No such name resolve path: %s", dir_path)
def get(self, name):
name = os.path.normpath(name)
path = self.__file_path(name)
if not os.path.isfile(path):
raise NameEntryNotFoundError(path)
@ -348,9 +373,15 @@ class NfsNameRecordRepository(NameRecordRepository):
dir_path = self.__dir_path(name_root)
rs = []
if os.path.isdir(dir_path):
for item in os.listdir(dir_path):
for root, _, files in os.walk(dir_path):
try:
rs.append(self.get(os.path.join(name_root, item)))
if len(files) != 1:
continue
if files[0] != "ENTRY":
continue
key = root.removeprefix(self.RECORD_ROOT)
key = key.removeprefix("/")
rs.append(self.get(key))
except NameEntryNotFoundError:
pass
return rs
@ -359,10 +390,15 @@ class NfsNameRecordRepository(NameRecordRepository):
dir_path = self.__dir_path(name_root)
rs = []
if os.path.isdir(dir_path):
for item in os.listdir(dir_path):
for root, _, files in os.walk(dir_path):
try:
self.get(os.path.join(name_root, item))
rs.append(os.path.join(name_root, item))
if len(files) != 1:
continue
if files[0] != "ENTRY":
continue
key = root.removeprefix(self.RECORD_ROOT)
key = key.removeprefix("/")
rs.append(key)
except NameEntryNotFoundError:
pass
rs.sort()
@ -374,7 +410,7 @@ class NfsNameRecordRepository(NameRecordRepository):
self.delete(name)
except:
pass
self.__to_delete = {}
self.__to_delete = set()
class RedisNameRecordRepository(NameRecordRepository):
@ -580,6 +616,8 @@ class Etcd3NameRecordRepository(NameRecordRepository):
)
self._keepalive_thread.start()
self._to_delete = set()
logger.info(f"Connected to etcd3 at {self._host}:{self._port}")
def __del__(self):
@ -623,7 +661,9 @@ class Etcd3NameRecordRepository(NameRecordRepository):
Raises:
NameEntryExistsError: If the key already exists and replace is False
"""
name = name.rstrip("/")
if not name:
raise ValueError(f"Invalid name: {name}")
name = os.path.normpath(name)
value = str(value)
with self._lock:
@ -641,9 +681,12 @@ class Etcd3NameRecordRepository(NameRecordRepository):
lease_id = self._create_lease(keepalive_ttl)
# Encode the string value to bytes
self._client.put(name, value.encode("utf-8"), lease=lease_id)
self._to_delete.add(name)
else:
# Encode the string value to bytes
self._client.put(name, value.encode("utf-8"))
if delete_on_exit:
self._to_delete.add(name)
# Store entry information for keepalive management
self._entries[name] = self._Entry(
@ -668,6 +711,8 @@ class Etcd3NameRecordRepository(NameRecordRepository):
"""
with self._lock:
self._delete_locked(name)
if name in self._to_delete:
self._to_delete.remove(name)
def _delete_locked(self, name):
"""Delete a key from etcd with lock already acquired.
@ -698,7 +743,7 @@ class Etcd3NameRecordRepository(NameRecordRepository):
"""
with self._lock:
count = 0
name_root = name_root.rstrip("/")
name_root = os.path.normpath(name_root)
# Get all keys with the prefix
for key_metadata_tuple in self._client.get_prefix(name_root):
key = key_metadata_tuple[1].key.decode(
@ -724,7 +769,7 @@ class Etcd3NameRecordRepository(NameRecordRepository):
"""
with self._lock:
rs = []
name_root = name_root.rstrip("/")
name_root = os.path.normpath(name_root)
for value_metadata_tuple in self._client.get_prefix(name_root):
value = value_metadata_tuple[0].decode("utf-8") # Extract the value
rs.append(value)
@ -760,6 +805,7 @@ class Etcd3NameRecordRepository(NameRecordRepository):
Raises:
NameEntryNotFoundError: If the key doesn't exist
"""
name = os.path.normpath(name)
with self._lock:
return self._get_locked(name)
@ -784,13 +830,14 @@ class Etcd3NameRecordRepository(NameRecordRepository):
"""Delete all keys added via this repository instance."""
with self._lock:
count = 0
for name in list(self._entries):
try:
self._delete_locked(name)
count += 1
except NameEntryNotFoundError:
pass
self._entries = {}
for name in self._to_delete:
if name in self._entries:
try:
self._delete_locked(name)
count += 1
except NameEntryNotFoundError:
pass
self._to_delete = set()
logger.info(f"Reset {count} saved etcd entries")
def _keepalive_thread_run(self):
@ -831,6 +878,17 @@ class Etcd3NameRecordRepository(NameRecordRepository):
if isinstance(names, str):
names = [names]
q = queue.Queue(maxsize=len(names))
for _ in range(len(names) - 1):
q.put(0)
def wrap_call_back():
try:
q.get_nowait()
except queue.Empty:
logger.info(f"Key {names} is gone. Executing callback {call_back}")
call_back()
# Use etcd's native watch capability for more efficient watching
for name in names:
# First wait for the key to exist
@ -838,7 +896,7 @@ class Etcd3NameRecordRepository(NameRecordRepository):
# Start watching for key deletion
watch_id = self._client.add_watch_callback(
name, lambda event: self._watch_callback(event, call_back)
name, lambda event: self._watch_callback(event, wrap_call_back)
)
# Store watch ID for cleanup

View File

@ -64,3 +64,34 @@ def distributed_master(experiment_name, trial_name, model_name):
def model_version(experiment_name, trial_name, model_name):
return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/model_version/{model_name}"
def metric_server_root(experiment_name, trial_name):
return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/metrics"
def metric_server(experiment_name, trial_name, group, name):
return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/metrics/{group}/{name}"
def push_pull_stream(experiment_name, trial_name, stream_name):
# Used to write addresses
return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/push_pull_stream/{stream_name}"
def push_pull_stream_root(experiment_name, trial_name):
# Used to collect addresses
return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/push_pull_stream/"
def stream_pullers(experiment_name, trial_name):
# Used to claim identities so that pushers know the number of pullers
return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/push_pull_stream_peers/"
def gen_servers(experiment_name, trial_name):
return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/gen_servers"
def gen_server_manager(experiment_name, trial_name):
return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/gen_server_manager"

View File

@ -2,6 +2,7 @@
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import hashlib
import os
import random
@ -10,10 +11,18 @@ import torch
import transformers
_SEED = None
_BASE_SEED = None
_SHUFFLER = None
def set_random_seed(seed):
global _SEED
def _seed_from_key(key: str) -> int:
return int(hashlib.sha256(key.encode()).hexdigest(), 16) & 0xFFFFFFFF
def set_random_seed(base_seed, key):
global _SEED, _BASE_SEED
_BASE_SEED = base_seed
seed = base_seed + _seed_from_key(key)
_SEED = seed
os.environ["PYTHONHASHSEED"] = str(seed)
transformers.set_seed(seed)
@ -27,4 +36,23 @@ def set_random_seed(seed):
def get_seed() -> int:
global _SEED
assert _SEED is not None
return _SEED
class Shuffler:
def __init__(self, key="default"):
self.cnt = 0
self.base_key = key
def next_shuffle(self) -> int:
shuffle_key = f"{self.base_key}_{self.cnt}"
self.cnt += 1
return _seed_from_key(shuffle_key)
def get_shuffle_seed() -> int:
global _BASE_SEED, _SHUFFLER
if _SHUFFLER is None:
_SHUFFLER = Shuffler(f"AReaL-seed{_BASE_SEED}")
return _SHUFFLER.next_shuffle()
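# A determinism sketch for the helpers above: seeds are pure functions of
# (base_seed, key), so re-running with the same pair reproduces them
# (the numeric values are not precomputed here):
#   set_random_seed(1, "model_worker/0")
#   s0 = get_shuffle_seed()  # derived from the key "AReaL-seed1_0"
#   s1 = get_shuffle_seed()  # derived from the key "AReaL-seed1_1"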

View File

@ -0,0 +1,251 @@
from collections import defaultdict
from enum import Enum, auto
from typing import Dict
import torch
import torch.distributed as dist
class ReduceType(Enum):
AVG = auto()
SUM = auto()
MIN = auto()
MAX = auto()
SCALAR = auto()
MOE_AUX_LOSSES = {}
class DistributedStatsTracker:
def __init__(self, name: str = ""):
self.scope_stack = []
if name:
self.scope_stack.append(name.strip("/"))
self.denominators = {} # key -> denominator key
self.reduce_types = {} # key -> ReduceType
self.stats = defaultdict(list)
def scope(self, name):
"""Context manager for hierarchical scoping"""
return self.Scope(self, name)
class Scope:
def __init__(self, tracker, name):
self.tracker = tracker
self.name = name.strip("/")
def __enter__(self):
self.tracker.scope_stack.append(self.name)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.tracker.scope_stack.pop()
def _get_full_key(self, key):
"""Combine scope stack with current key"""
if not self.scope_stack:
return key
return "/".join(self.scope_stack + [key])
def denominator(self, **kwargs):
for key, value in kwargs.items():
if not isinstance(value, torch.Tensor) or value.dtype != torch.bool:
raise ValueError(
f"`{key}` must be a pytorch bool tensor: {value.dtype}"
)
if value.numel() == 0:
raise ValueError(f"`{key}` must be non-empty")
full_key = self._get_full_key(key)
self._set_reduce_type(full_key, ReduceType.SUM)
self.stats[full_key].append(value.detach().clone())
def scalar(self, **kwargs):
for key, value in kwargs.items():
full_key = self._get_full_key(key)
self._set_reduce_type(full_key, ReduceType.SCALAR)
self.stats[full_key].append(float(value))
def stat(
self,
denominator: str,
reduce_type: ReduceType | None = None,
**kwargs,
):
"""Record multiple values from a dictionary"""
for key, value in kwargs.items():
if not isinstance(value, torch.Tensor) or value.dtype != torch.float:
raise ValueError(
f"`{key}` should be a pytorch float tensor: {value.dtype}"
)
if value.numel() == 0:
raise ValueError(f"`{key}` should be non-empty")
if reduce_type == ReduceType.SCALAR:
raise ValueError("Cannot use the scalar reduce type for a tensor")
full_key = self._get_full_key(key)
denorm = self._get_full_key(denominator)
if denorm not in self.stats or not self.stats[denorm]:
raise ValueError(f"Denominator `{denorm}` does not exist")
for x, y in zip(self.stats[denorm], self.stats[full_key] + [value]):
assert x.shape == y.shape, (x.shape, y.shape)
self.denominators[full_key] = denorm
if reduce_type is not None:
self._set_reduce_type(full_key, reduce_type)
self.stats[full_key].append(value.detach().clone())
def _set_reduce_type(self, key, reduce_type):
if not isinstance(reduce_type, ReduceType):
raise ValueError("reduce_type must be a ReduceType enum")
self.reduce_types[key] = reduce_type
def export(self, key=None, reduce_group=None, reset=True) -> Dict[str, float]:
"""Get aggregated statistics"""
self._amend_moe_losses()
if reduce_group is None:
try:
from realhf.base.constants import data_parallel_group
reduce_group = data_parallel_group()
except:
pass
if key is not None:
full_key = self._get_full_key(key)
result = self._aggregate(full_key, reduce_group)
if reset:
if full_key in self.denominators:
self.denominators.pop(full_key)
if full_key in self.reduce_types:
                    self.reduce_types.pop(full_key)
self.stats.pop(full_key)
return result
results = {}
for key in list(self.stats.keys()):
results.update(self._aggregate(key, reduce_group))
if reset:
self.denominators = {}
self.reduce_types = {}
self.stats = defaultdict(list)
results = {
k: v.cpu().item() if torch.is_tensor(v) else v for k, v in results.items()
}
return results
def _amend_moe_losses(self):
from realhf.base.constants import is_last_pipe_stage, pipe_parallel_group
global MOE_AUX_LOSSES
mean_losses = {}
for k, loss in MOE_AUX_LOSSES.items():
dist.all_reduce(loss, group=pipe_parallel_group())
mean_losses[k] = float(loss.mean()) # average over layers
MOE_AUX_LOSSES.clear()
if mean_losses and is_last_pipe_stage():
self.scalar(**mean_losses)
def _aggregate(self, key, reduce_group):
if key not in self.stats or not self.stats[key]:
return {}
reduce_type = self.reduce_types.get(key, None)
result = {}
if reduce_type is None:
result["/".join([key, "avg"])] = self._avg_of(key, reduce_group)
result["/".join([key, "min"])] = self._min_of(key, reduce_group)
result["/".join([key, "max"])] = self._max_of(key, reduce_group)
elif reduce_type == ReduceType.AVG:
result[key] = self._avg_of(key, reduce_group)
elif reduce_type == ReduceType.SUM:
result[key] = self._sum_of(key, reduce_group)
elif reduce_type == ReduceType.MIN:
result[key] = self._min_of(key, reduce_group)
elif reduce_type == ReduceType.MAX:
result[key] = self._max_of(key, reduce_group)
elif reduce_type == ReduceType.SCALAR:
result[key] = sum(self.stats[key]) / len(self.stats[key])
else:
raise ValueError(f"Unknown reduce type: {reduce_type}")
return result
def _sum_of(self, key, reduce_group):
values = self.stats[key]
if key not in self.denominators:
x = sum([x.sum() for x in values])
if reduce_group is not None:
dist.all_reduce(x, group=reduce_group)
else:
denominator = self.denominators[key]
if denominator not in self.stats:
raise ValueError(
f"Denominator `{denominator}` not set for key `{key}`."
)
xs = []
for v, d in zip(values, self.stats[denominator]):
xs.append(torch.where(d, v, 0.0).sum())
x = sum(xs)
if reduce_group is not None:
dist.all_reduce(x, group=reduce_group)
return float(x)
def _avg_of(self, key, reduce_group):
values = self.stats[key]
denominator = self.denominators[key]
if denominator not in self.stats:
raise ValueError(f"Denominator `{denominator}` not set for key `{key}`.")
xs = []
ds = []
for v, d in zip(values, self.stats[denominator]):
xs.append(torch.where(d, v, 0.0).sum())
ds.append(d.sum())
x = sum(xs)
d = sum(ds)
if reduce_group is not None:
dist.all_reduce(x, group=reduce_group)
dist.all_reduce(d, group=reduce_group)
if d == 0:
return 0
return x / d
def _min_of(self, key, reduce_group):
values = self.stats[key]
denominator = self.denominators[key]
if denominator not in self.stats:
raise ValueError(f"Denominator `{denominator}` not set for key `{key}`.")
xs = []
for v, d in zip(values, self.stats[denominator]):
xs.append(torch.where(d, v, float("inf")).min())
x = min(xs)
if reduce_group is not None:
dist.all_reduce(x, group=reduce_group, op=dist.ReduceOp.MIN)
if torch.isinf(x):
return float("nan")
return float(x)
def _max_of(self, key, reduce_group):
values = self.stats[key]
denominator = self.denominators[key]
if denominator not in self.stats:
raise ValueError(f"Denominator `{denominator}` not set for key `{key}`.")
xs = []
for v, d in zip(values, self.stats[denominator]):
xs.append(torch.where(d, v, -float("inf")).max())
x = max(xs)
if reduce_group is not None:
dist.all_reduce(x, group=reduce_group, op=dist.ReduceOp.MAX)
if torch.isinf(x):
return float("nan")
return float(x)
DEFAULT_TRACKER = DistributedStatsTracker()
stat = DEFAULT_TRACKER.stat
denominator = DEFAULT_TRACKER.denominator
export = DEFAULT_TRACKER.export
scope = DEFAULT_TRACKER.scope
scalar = DEFAULT_TRACKER.scalar
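A minimal usage sketch of this module-level API (illustrative values only; it assumes the module is importable as realhf.base.stats_tracker, as imports later in this commit suggest, and a non-distributed context, where export() silently falls back to no cross-rank reduction):

import torch

from realhf.base import stats_tracker

mask = torch.tensor([True, True, False, True])  # 3 valid positions
loss = torch.tensor([0.5, 0.3, 9.9, 0.2])       # per-position float statistics

stats_tracker.denominator(n_valid_tokens=mask)  # bool tensor, reduced by SUM
stats_tracker.stat(loss=loss, denominator="n_valid_tokens")
# Without an explicit reduce type, export() reports avg/min/max over the masked values:
# {'n_valid_tokens': 3.0, 'loss/avg': 0.333..., 'loss/min': 0.2, 'loss/max': 0.5}
print(stats_tracker.export())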

View File

@ -0,0 +1,91 @@
# Licensed under the Apache License, Version 2.0 (the "License").
import copy
import dataclasses
from typing import Any, Dict, List, Tuple
import realhf.base.logging as logging
from realhf.api.cli_args import ModelTrainEvalConfig, PPOMATHExperimentOptions
from realhf.api.core.config import AgentAbstraction, EnvServiceAbstraction
from realhf.api.core.model_api import GenerationHyperparameters
from realhf.api.quickstart.entrypoint import register_quickstart_exp
from realhf.experiments.async_exp.async_rl_exp import AsyncRLExperimentConfig
from realhf.experiments.common.ppo_math_exp import PPOMATHConfig
from realhf.experiments.common.utils import asdict
logger = logging.getLogger("Async PPO Math exp", "colored")
@dataclasses.dataclass
class AsyncPPOMATHConfig(AsyncRLExperimentConfig, PPOMATHConfig):
@property
def agent(self) -> AgentAbstraction:
return AgentAbstraction(
"math-single-step",
args=dict(
gconfig=self.generation_config,
tokenizer_path=self.actor.path,
success_rate_lb=self.success_rate_lb,
success_rate_ub=self.success_rate_ub,
reward_scaling=self.ppo.reward_output_scaling,
reward_bias=self.ppo.reward_output_bias,
),
)
@property
def env(self) -> EnvServiceAbstraction:
return EnvServiceAbstraction(
"math-single-step", args=dict(dataset_path=self.dataset.path)
)
@property
def gen_backend_args(self) -> Any:
return self.actor.sglang
@property
def generation_config(self) -> GenerationHyperparameters:
return GenerationHyperparameters(**asdict(self.ppo.gen)).new(n=self.group_size)
@property
def rpcs(self):
rpcs = super(AsyncPPOMATHConfig, self).rpcs
rpcs["actor_gen"].output_keys = (
*rpcs["actor_gen"].output_keys,
"packed_prompts",
"version_start",
"version_end",
"rewards",
)
rpcs["actor_train"].input_keys = (
*rpcs["actor_train"].input_keys,
"version_start",
"version_end",
)
# Revert the effect of fuse_rew_ref, because we don't have the reward RPC in async experiments.
if "ref_inf" in rpcs:
actor_interface = rpcs["actor_train"].interface_impl
rpcs["ref_inf"].interface_impl = copy.deepcopy(actor_interface)
rpcs["ref_inf"].interface_impl.args["enable_save"] = False
rpcs["ref_inf"].input_keys = ("packed_input_ids",)
rpcs["ref_inf"].output_keys = ("packed_ref_logprobs",)
if "rew_inf" in rpcs:
rpcs.pop("rew_inf")
return rpcs
@property
def models(self) -> Dict[str, ModelTrainEvalConfig]:
models = super().models
if "reward" in models:
models.pop("reward")
return models
@property
def allocations(self):
allocations = super(AsyncPPOMATHConfig, self).allocations
if "rew_inf" in allocations:
allocations.pop("rew_inf")
return allocations
register_quickstart_exp("async-ppo-math", AsyncPPOMATHConfig)

View File

@ -0,0 +1,348 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import dataclasses
import itertools
import os
from collections import defaultdict
from typing import *
import transformers
import realhf.base.logging as logging
from realhf.api.cli_args import AsyncRLOptions, ParallelismConfig
from realhf.api.core.config import (
AgentAbstraction,
DatasetAbstraction,
EnvServiceAbstraction,
ModelAbstraction,
ModelBackendAbstraction,
ModelName,
ModelShardID,
StandaloneModelShardAbstraction,
)
from realhf.api.core.dfg import ModelInterfaceType
from realhf.api.core.model_api import (
HF_MODEL_FAMILY_REGISTRY,
GenerationHyperparameters,
)
from realhf.api.core.system_api import (
ExperimentConfig,
ExperimentScheduling,
GenerationServer,
GserverManager,
ModelWorker,
RolloutWorker,
Scheduling,
TasksGroup,
)
from realhf.api.quickstart.device_mesh import RPCAllocation
from realhf.base.cluster import spec as cluster_spec
from realhf.experiments.common.check import check_valid_sglang, check_valid_vllm
from realhf.experiments.common.common import CommonExperimentConfig
from realhf.experiments.common.utils import (
AllocationMode,
asdict,
get_real_model_config,
get_topo,
make_inf_backend_config,
make_train_backend_config,
resolve_replica_ids,
resolve_rpc_hooks,
)
logger = logging.getLogger("AsyncRLExperimentConfig", "colored")
GEN_HYBRID_TRAIN_DECOUPLE_ALLOC_WARN = False
GEN_WORKER_DEFAULT_CAPACITY = 512
@dataclasses.dataclass
class AsyncRLExperimentConfig(CommonExperimentConfig, AsyncRLOptions):
@property
def generation_config(self) -> GenerationHyperparameters:
raise NotImplementedError()
@property
def env(self) -> EnvServiceAbstraction:
return EnvServiceAbstraction("null")
@property
def agent(self) -> AgentAbstraction:
return AgentAbstraction("null")
@property
def gen_backend_args(self) -> Any:
raise NotImplementedError()
@property
def get_backend_type(self) -> str:
return "sglang"
def scheduling_setup(self) -> ExperimentScheduling:
"""The resourced occupied by each worker.
The resource requirements will be sent to SLURM or Ray, while
being ignored in the local mode.
"""
gen_world_size = AllocationMode.from_str(self.allocation_mode).get_gen_size()
train_world_size = self.n_nodes * self.n_gpus_per_node - gen_world_size
gen_tp_size = AllocationMode.from_str(self.allocation_mode).get_gen_tp_size()
return ExperimentScheduling(
master_worker=TasksGroup(
count=1,
scheduling=Scheduling.master_worker_default(
cpu=self.cpus_per_master_worker,
mem=self.mem_per_master_worker,
nodelist=self.nodelist,
exclude=self.exclude,
),
),
model_worker=TasksGroup(
count=train_world_size,
scheduling=Scheduling.model_worker_default(
cpu=self.cpus_per_model_worker,
gpu=1,
gpu_type=cluster_spec.gpu_type,
mem=self.mem_per_model_worker,
nodelist=self.nodelist,
exclude=self.exclude,
),
),
generation_server=TasksGroup(
count=gen_world_size // gen_tp_size,
scheduling=Scheduling.generation_server_default(
cpu=self.cpus_per_generation_server,
gpu=gen_tp_size,
gpu_type=cluster_spec.gpu_type,
mem=self.mem_per_generation_server,
nodelist=self.nodelist,
exclude=self.exclude,
),
),
gserver_manager=TasksGroup(
count=1,
scheduling=Scheduling.gserver_manager_default(
cpu=self.cpus_per_gserver_manager,
gpu_type=cluster_spec.gpu_type,
mem=self.mem_per_gserver_manager,
nodelist=self.nodelist,
exclude=self.exclude,
),
),
rollout_worker=TasksGroup(
count=self.n_rollout_workers or train_world_size,
scheduling=Scheduling.rollout_worker_default(
cpu=self.cpus_per_rollout_worker,
gpu_type=cluster_spec.gpu_type,
mem=self.mem_per_rollout_worker,
nodelist=self.nodelist,
exclude=self.exclude,
),
),
)
def _get_model_worker_configs(
self, rpc_allocs: List[RPCAllocation]
) -> List[ModelWorker]:
self._run_model_sanity_check(rpc_allocs)
model_worker = []
shard_counter = defaultdict(lambda: 0)
model_name_to_rpc_allocs: Dict[ModelName, List[RPCAllocation]] = defaultdict(
list
)
for rpc_alloc in rpc_allocs:
model_name_to_rpc_allocs[rpc_alloc.rpc.model_name].append(rpc_alloc)
for i, j in itertools.product(range(self.n_nodes), range(self.n_gpus_per_node)):
if self.gen_device_mesh.mapping[i, j]:
continue
mw = ModelWorker(
base_seed=self.seed,
shards=[],
# NOTE: here we use puller stream to wrap the original dataset
datasets=[
DatasetAbstraction(
"puller_stream", args=dict(dataset_cfgs=self.datasets)
)
],
torch_cache_mysophobia=self.torch_cache_mysophobia,
cuda_cache_cleanliness=self.cache_clear_freq is not None,
cuda_cache_clear_freq=self.cache_clear_freq,
tokenizer_name_or_path=self.tokenizer_name_or_path,
)
for (
model_name,
model_rpc_allocs,
) in model_name_to_rpc_allocs.items():
rpcs = [rpc_alloc.rpc for rpc_alloc in model_rpc_allocs]
if self._allocation_mode.is_decoupled() and all(
rpc.is_generate() for rpc in rpcs
):
continue
rpc_alloc = model_rpc_allocs[0]
model_cfg = self.models[model_name.role]
model = get_real_model_config(
model_path=model_cfg.path,
hf_model_family=model_cfg.type._class,
is_critic=model_cfg.type.is_critic,
init_from_scratch=model_cfg.init_from_scratch,
init_critic_from_actor=model_cfg.init_critic_from_actor,
dtype="bf16" if model_cfg.bf16 else "fp16",
)
hf_config = transformers.AutoConfig.from_pretrained(
model_cfg.path,
trust_remote_code=True,
force_download=True,
)
model_config = HF_MODEL_FAMILY_REGISTRY[model_cfg.type._class][
"config_from_hf_converter"
](hf_config)
if (
model_config.n_kv_heads % rpc_alloc.parallel.model_parallel_size
!= 0
) or (
model_config.n_q_heads % rpc_alloc.parallel.model_parallel_size != 0
):
raise ValueError(
f"The number of KV heads {model_config.n_kv_heads} or "
f"Q heads {model_config.n_q_heads} is not"
f" divisible by the configured TP size "
f"({rpc_alloc.parallel.model_parallel_size}). "
f"Please decrease TP size."
)
mapping = rpc_alloc.device_mesh.mapping
gradient_checkpointing = model_cfg.gradient_checkpointing and any(
rpc.interface_type == ModelInterfaceType.TRAIN_STEP for rpc in rpcs
)
topo = get_topo(
rpc_alloc.parallel,
gradient_checkpointing=gradient_checkpointing,
max_prompt_len=(
self.max_prompt_len
if any(
rpc.interface_type == ModelInterfaceType.GENERATE
for rpc in rpcs
)
else None
),
gradient_accumulation_fusion=(model_cfg.backend == "megatron")
and (model_cfg.type._class != "bailing"),
is_train=any(rpc.is_train() for rpc in rpcs),
)
if any(rpc.is_train() for rpc in rpcs):
backend = make_train_backend_config(model_cfg, rpc_alloc.parallel)
else:
backend = make_inf_backend_config(model_cfg, rpc_alloc.parallel)
if mapping[i, j]:
shard_idx = shard_counter[model_name]
mw.shards.append(
StandaloneModelShardAbstraction(
id=ModelShardID(
model_name=model_name,
topo=topo,
dp_rank=topo.get_coord(shard_idx).data,
pp_rank=topo.get_coord(shard_idx).pipe,
mp_rank=topo.get_coord(shard_idx).model,
),
model=model,
backend=backend,
eval_dataset=self.eval_dataset,
eval_bs=self.eval_bs,
)
)
shard_counter[model_name] += 1
model_worker.append(mw)
return model_worker
def get_rollout_worker_configs(self, rpc_allocs):
gen_world_size = AllocationMode.from_str(self.allocation_mode).get_gen_size()
train_world_size = self.n_nodes * self.n_gpus_per_node - gen_world_size
gen_rpc_alloc = next(alloc for alloc in rpc_allocs if alloc.rpc.is_generate())
model_name = gen_rpc_alloc.rpc.model_name
return [
RolloutWorker(
base_seed=self.seed,
model_name=model_name,
tokenizer_path=self.tokenizer_name_or_path,
new_tokens_per_chunk=self.new_tokens_per_chunk,
env=self.env,
agent=self.agent,
datasets=self.datasets,
rollout_request_timeout=self.flush_request_timeout,
)
for _ in range(self.n_rollout_workers or train_world_size)
]
def get_generation_server_configs(self, rpc_allocs):
am = AllocationMode.from_str(self.allocation_mode)
gen_world_size = am.get_gen_size()
gen_tp_size = am.get_gen_tp_size()
gen_rpc_alloc = next(alloc for alloc in rpc_allocs if alloc.rpc.is_generate())
model_name = gen_rpc_alloc.rpc.model_name
model_cfg = self.models[model_name.role]
return [
GenerationServer(
base_seed=self.seed,
backend_type=self.get_backend_type,
backend_args=self.gen_backend_args,
model_path=model_cfg.path,
tp_size=gen_tp_size,
)
for _ in range(gen_world_size // gen_tp_size)
]
def get_gserver_manager_config(self, rpc_allocs):
am = AllocationMode.from_str(self.allocation_mode)
gen_world_size = am.get_gen_size()
gen_tp_size = am.get_gen_tp_size()
gen_rpc_alloc = next(alloc for alloc in rpc_allocs if alloc.rpc.is_generate())
model_name = gen_rpc_alloc.rpc.model_name
train_rpcs = [alloc.rpc for alloc in rpc_allocs if alloc.rpc.is_train()]
assert all(rpc.n_seqs == train_rpcs[0].n_seqs for rpc in train_rpcs)
return [
GserverManager(
model_name=model_name,
flush_request_timeout=self.flush_request_timeout,
n_servers=gen_world_size // gen_tp_size,
schedule_policy="round_robin",
max_head_offpolicyness=self.max_head_offpolicyness,
train_batch_size=train_rpcs[0].n_seqs,
max_concurrent_rollouts=self.max_concurrent_rollouts,
)
]
def initial_setup(self) -> ExperimentConfig:
assert self._allocation_mode.is_decoupled(), self._allocation_mode
rpc_allocs = self._get_rpc_allocations()
resolve_replica_ids(rpc_allocs, self.models)
resolve_rpc_hooks(
rpc_allocs, self.models
) # inplace modify MFCDefs in rpc allocations
return ExperimentConfig(
exp_ctrl=self.exp_ctrl,
wandb=self.wandb,
tensorboard=self.tensorboard,
# NOTE: master and model worker only see RPCs without generation
model_rpcs=[
rpc_alloc.rpc
for rpc_alloc in rpc_allocs
if not rpc_alloc.rpc.is_generate()
],
model_worker=self._get_model_worker_configs(rpc_allocs),
generation_server=self.get_generation_server_configs(rpc_allocs),
gserver_manager=self.get_gserver_manager_config(rpc_allocs),
rollout_worker=self.get_rollout_worker_configs(rpc_allocs),
auto_eval=self.auto_eval,
evaluator=self.auto_eval_config,
)
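As a sanity check of the decoupled allocation arithmetic used by scheduling_setup above (a hypothetical parallel strategy, not taken from this diff): with a generation strategy of d=4, p=1, m=2 on 2 nodes of 8 GPUs, get_gen_size() yields 8 generation GPUs and get_gen_tp_size() yields 2, so 4 generation servers are launched and the remaining 8 GPUs host model workers.

# Hypothetical numbers for illustration only.
n_nodes, n_gpus_per_node = 2, 8
gen_parallel = {"d": 4, "p": 1, "m": 2}

gen_world_size = gen_parallel["d"] * gen_parallel["p"] * gen_parallel["m"]  # 8, cf. AllocationMode.get_gen_size()
gen_tp_size = gen_parallel["m"]                                             # 2, cf. AllocationMode.get_gen_tp_size()
n_generation_servers = gen_world_size // gen_tp_size                        # 4 servers with 2 GPUs each
train_world_size = n_nodes * n_gpus_per_node - gen_world_size               # 8 GPUs left for model workers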

View File

@ -110,7 +110,7 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
Should be implemented in all subclasses.
"""
return NotImplementedError(f"datasets is not implemented in {self.__class__}")
raise NotImplementedError(f"datasets is not implemented in {self.__class__}")
@property
def eval_dataset(self) -> DatasetAbstraction | None:
@ -154,8 +154,6 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
n_nodes=self.n_nodes,
n_gpus_per_node=self.n_gpus_per_node,
mapping=np.ones((self.n_nodes, self.n_gpus_per_node), dtype=np.int32),
global_mesh_name=self.nodelist,
name=self.nodelist,
)
def _heuristic_rpc_allocation(self) -> List[RPCAllocation]:
@ -190,6 +188,7 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
cpu=self.cpus_per_master_worker,
mem=self.mem_per_master_worker,
nodelist=self.nodelist,
exclude=self.exclude,
),
),
model_worker=TasksGroup(
@ -200,6 +199,7 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
gpu_type=cluster_spec.gpu_type,
mem=self.mem_per_model_worker,
nodelist=self.nodelist,
exclude=self.exclude,
),
),
)
@ -351,12 +351,16 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
)
rpc_allocs.append(alloc)
elif self.allocation_mode == "manual":
if self.nodelist is None:
raise ValueError(
"The 'nodelist' option must be specified when using manual allocation mode."
)
rpc_allocs: List[RPCAllocation] = [
RPCAllocation(
rpc=rpc,
device_mesh=(
make_device_mesh_from_name(
self.global_device_mesh.name,
self.nodelist,
self.allocations[rpc_type].device_mesh,
self.global_device_mesh.n_gpus_per_node,
)
@ -394,6 +398,7 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
base_seed=self.seed,
shards=[],
datasets=self.datasets,
shuffle_dataset=self.shuffle_dataset,
torch_cache_mysophobia=self.torch_cache_mysophobia,
cuda_cache_cleanliness=self.cache_clear_freq is not None,
cuda_cache_clear_freq=self.cache_clear_freq,

View File

@ -0,0 +1,123 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import dataclasses
from typing import Dict
import realhf.base.logging as logging
from realhf.api.cli_args import MathCodeEvalOptions, ModelTrainEvalConfig
from realhf.api.core.config import (
DatasetAbstraction,
ModelInterfaceAbstraction,
ModelInterfaceType,
)
from realhf.api.core.dfg import MFCDef
from realhf.api.quickstart.entrypoint import register_quickstart_exp
from realhf.experiments.common.common import CommonExperimentConfig
from realhf.experiments.common.utils import asdict
logger = logging.getLogger("Math Cdoe Eval exp", "colored")
@dataclasses.dataclass
class MathCodeEvalConfig(MathCodeEvalOptions, CommonExperimentConfig):
@property
def models(self) -> Dict[str, ModelTrainEvalConfig]:
return {
"actor": self.actor,
"reward": self.rew,
}
@property
def rpcs(self):
if (
self.dataset.max_prompt_len + self.gen_config.max_new_tokens
> self.actor.vllm.max_seq_len_to_capture
):
raise RuntimeError(
f"vllm max seq len to capture {self.actor.vllm.max_seq_len_to_capture} is "
f"smaller than the prompt length + generation length {self.dataset.max_prompt_len + self.gen_config.max_new_tokens}"
)
# interfaces
actor_interface = ModelInterfaceAbstraction(
"ppo_actor",
args={
"generation_config": asdict(self.gen_config),
"group_size": self.group_size,
},
)
rw_interface = ModelInterfaceAbstraction(
"rw-math-code",
args=dict(
dataset_path=self.dataset.path,
tokenizer_path=self.actor.path,
rw_type=self.rw_type,
check_xml_format=self.check_xml_format,
group_size=self.group_size,
check_verifier_status=self.check_verifier_status,
),
)
rollout = MFCDef(
name="actor_gen",
model_name="actor",
mb_spec=self.actor_gen.mb_spec,
interface_type=ModelInterfaceType.GENERATE,
model_type=self.actor.type,
model_path=self.actor.path,
interface_impl=actor_interface,
input_keys=("packed_prompts", "task_ids"),
output_keys=("packed_input_ids",),
n_seqs=self.dataset.train_bs_n_seqs,
)
inf_reward = MFCDef(
name="rew_inf",
model_name="reward",
mb_spec=self.rew_inf.mb_spec,
interface_type=ModelInterfaceType.INFERENCE,
interface_impl=rw_interface,
model_type=self.rew.type,
model_path=self.rew.path,
min_n_seqs_per_pass=1 / self.group_size,
input_keys=("packed_input_ids", "packed_prompts", "task_ids"),
output_keys=("rewards",),
n_seqs=self.dataset.train_bs_n_seqs,
)
return {
"actor_gen": rollout,
"rew_inf": inf_reward,
}
@property
def allocations(self):
return {
"actor_gen": self.actor_gen,
"rew_inf": self.rew_inf,
}
@property
def datasets(self):
return [
DatasetAbstraction(
"math_code_prompt",
args=dict(
dataset_path=self.dataset.path,
max_length=self.dataset.max_prompt_len,
),
)
]
@property
def tokenizer_name_or_path(self) -> str:
return self.actor.path
@property
def max_prompt_len(self):
return self.dataset.max_prompt_len
register_quickstart_exp("math-code-eval", MathCodeEvalConfig)

View File

@ -4,7 +4,6 @@
import copy
import dataclasses
import os
import pprint
from typing import Dict
import realhf.base.logging as logging
@ -29,7 +28,6 @@ logger = logging.getLogger("PPO Math exp", "colored")
@dataclasses.dataclass
class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
@property
def ppo_kwargs(self):
return dict(
@ -61,7 +59,9 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
}
if self.ppo.disable_value:
models.pop("critic")
if self.ppo.fuse_rew_ref:
if self.ppo.kl_ctl == 0:
models.pop("ref")
if self.ppo.fuse_rew_ref and self.ppo.kl_ctl != 0:
models.pop("reward")
return models
@ -80,7 +80,11 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
)
domain = os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", "")
if domain and (not (domain.startswith("http://") and ":" in domain)):
if (
domain
and (not (domain.startswith("http://") and ":" in domain))
and (not (domain.startswith("https://") and ":" in domain))
):
raise RuntimeError(
"function call address FUNCTIONCALL_SERVICE_DOMAIN is invalid."
)
@ -101,6 +105,8 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
"generation_size": self.generation_size,
"group_adv_norm": self.group_adv_norm,
"mask_too_long": self.mask_too_long,
"sample_reuse": self.ppo.actor_sample_reuse,
"c_clip": self.ppo.c_clip,
},
)
@ -110,6 +116,7 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
**copy.deepcopy(self.ppo_kwargs),
"group_size": self.group_size,
"mask_too_long": self.mask_too_long,
"sample_reuse": self.ppo.critic_sample_reuse,
},
)
critic_interface.args.pop("eps_clip")
@ -222,12 +229,15 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
"packed_logprobs",
"packed_ref_logprobs",
"rewards",
"task_ids",
"values",
"prompt_mask",
"seq_no_eos_mask",
]
if self.ppo.disable_value:
train_actor_inputs.remove("values")
if self.ppo.kl_ctl == 0:
train_actor_inputs.remove("packed_ref_logprobs")
train_actor = MFCDef(
name="actor_train",
model_name="actor",
@ -242,6 +252,17 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
n_seqs=self.dataset.train_bs_n_seqs,
)
train_critic_inputs = [
"packed_input_ids",
"packed_logprobs",
"packed_ref_logprobs",
"rewards",
"values",
"prompt_mask",
"seq_no_eos_mask",
]
if self.ppo.kl_ctl == 0:
train_critic_inputs.remove("packed_ref_logprobs")
train_critic = MFCDef(
name="critic_train",
model_name="critic",
@ -250,15 +271,7 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
interface_impl=critic_interface,
model_type=self.critic.type,
model_path=self.critic.path,
input_keys=(
"packed_input_ids",
"packed_logprobs",
"packed_ref_logprobs",
"rewards",
"values",
"prompt_mask",
"seq_no_eos_mask",
),
input_keys=tuple(train_critic_inputs),
log_return_value=True,
min_n_seqs_per_pass=self.ppo.ppo_n_minibatches / self.group_size,
n_seqs=self.dataset.train_bs_n_seqs,
@ -278,7 +291,9 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
rpcs.pop("critic_train")
if not self.ppo.recompute_logprob:
rpcs.pop("actor_inf")
if self.ppo.fuse_rew_ref:
if self.ppo.kl_ctl == 0:
rpcs.pop("ref_inf")
if self.ppo.fuse_rew_ref and self.ppo.kl_ctl != 0:
rpcs.pop("rew_inf")
return rpcs
@ -298,7 +313,9 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
allocs.pop("critic_train")
if not self.ppo.recompute_logprob:
allocs.pop("actor_inf")
if self.ppo.fuse_rew_ref:
if self.ppo.kl_ctl == 0:
allocs.pop("ref_inf")
if self.ppo.fuse_rew_ref and self.ppo.kl_ctl != 0:
allocs.pop("rew_inf")
return allocs
@ -340,8 +357,6 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
rpc_allocs, self.models
) # inplace modify MFCDefs in rpc allocations
pprint.pprint(rpc_allocs)
######### update ref model using ema, ref_ema_eta = 0 means fixed ref model #########
def _find_rpc(name):
return next(alloc.rpc for alloc in rpc_allocs if alloc.rpc.name == name)

View File

@ -255,7 +255,7 @@ class AllocationType(enum.Enum):
@dataclasses.dataclass
class AllocationMode:
type_: AllocationType
parallel_strat: Dict[str, Dict[str, int]]
parallel_strat: None | Dict[str, Dict[str, int]]
def is_decoupled(self):
return self.type_ in [
@ -276,6 +276,17 @@ class AllocationMode:
def is_global_hybrid(self):
return self.type_ == AllocationType.GLOBAL_HYBRID
def get_gen_size(self):
assert self.is_decoupled()
paras = self.parallel_strat
gdp, gpp, gmp = paras["gen"]["d"], paras["gen"]["p"], paras["gen"]["m"]
return gdp * gpp * gmp
def get_gen_tp_size(self):
assert self.is_decoupled()
paras = self.parallel_strat
return paras["gen"]["m"]
@classmethod
def from_str(cls, allocation_mode: str):
if allocation_mode == "manual":

View File

@ -0,0 +1,2 @@
import realhf.impl.agent.math_single_step_agent
import realhf.impl.agent.null_agent

View File

@ -0,0 +1,240 @@
# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0 (the "License").
import asyncio
import json
import os
from typing import List
import colorama
import numpy as np
import torch
from realhf.api.core.agent_api import Agent, register_agent
from realhf.api.core.data_api import SequenceSample, load_hf_tokenizer
from realhf.api.core.env_api import EnvironmentService
from realhf.api.core.model_api import BundledGenerationOutputs
from realhf.base import constants, logging
logger = logging.getLogger("Math Code Agent")
class MathSingleStepAgent(Agent):
def __init__(
self,
gconfig,
tokenizer_path,
success_rate_lb,
success_rate_ub,
reward_scaling=1.0,
reward_bias=0.0,
):
self.gconfig = gconfig
self.tokenizer = load_hf_tokenizer(tokenizer_path)
self.success_rate_lb = success_rate_lb
self.success_rate_ub = success_rate_ub
self.reward_scaling = reward_scaling
self.reward_bias = reward_bias
async def collect_trajectory(
self,
prompt: SequenceSample,
env: EnvironmentService,
obs_queue: asyncio.Queue,
act_queue: asyncio.Queue,
) -> List[SequenceSample]:
# reset does nothing, just to make it like multi-step environments
await env.reset()
assert prompt.bs == 1
prompt_token_ids = prompt.data["packed_prompts"].cpu().numpy().tolist()
qid = prompt.ids[0]
await obs_queue.put((qid, prompt_token_ids, self.gconfig))
act: BundledGenerationOutputs = await act_queue.get()
seq_strs = self.tokenizer.batch_decode(
act.seqs,
clean_up_tokenization_spaces=False,
skip_special_tokens=True,
)
prompt_str = self.tokenizer.batch_decode(
[act.prompt_ids],
clean_up_tokenization_spaces=False,
skip_special_tokens=True,
)[0]
answers = [seq_str.split(prompt_str)[1] for seq_str in seq_strs]
# single-step env
_, success, *_ = await env.step((qid, answers))
rewards = [
((float(r) - 0.5) * 2 - self.reward_bias) * self.reward_scaling
for r in success
]
self.log_rewards_to_file(
str(qid),
prompt_str,
seqlens=[len(s) for s in act.seqs],
answers=answers,
prompt_len=len(prompt_token_ids),
rewards=rewards,
success=success,
version_starts=act.version_start,
version_ends=act.version_end,
)
r = np.mean([float(s) for s in success])
if r < self.success_rate_lb:
logger.info(f"Query ID {qid} reward too low: {r} < {self.success_rate_lb}.")
return []
if r > self.success_rate_ub:
logger.info(
f"Query ID {qid} reward too high: {r} > {self.success_rate_ub}."
)
return []
x = SequenceSample(
keys=[
"packed_input_ids",
"prompt_mask",
"packed_logprobs",
"seq_no_eos_mask",
"packed_prompts",
"version_start",
"version_end",
"rewards",
"task_ids",
],
ids=[qid],
dtypes=dict(
packed_prompts=torch.long,
packed_input_ids=torch.long,
prompt_mask=torch.bool,
seq_no_eos_mask=torch.bool,
version_start=torch.int,
version_end=torch.int,
packed_logprobs=torch.float32,
rewards=torch.float32,
task_ids=torch.long,
),
trailing_shapes=dict(
packed_input_ids=(),
prompt_mask=(),
seq_no_eos_mask=(),
packed_prompts=(),
version_end=(),
version_start=(),
packed_logprobs=(),
rewards=(),
task_ids=(),
),
seqlens=dict(
packed_input_ids=[act.seqlens],
packed_logprobs=[[s - 1 for s in act.seqlens]],
packed_prompts=[[act.prompt_len]],
prompt_mask=[act.seqlens],
seq_no_eos_mask=[[1 for _ in range(self.gconfig.n)]],
rewards=[[1 for _ in range(self.gconfig.n)]],
version_start=[[1 for _ in range(self.gconfig.n)]],
version_end=[[1 for _ in range(self.gconfig.n)]],
task_ids=[[1]],
),
data=dict(
packed_prompts=torch.tensor(act.prompt_ids, dtype=torch.long),
packed_logprobs=torch.tensor(
sum(act.logprobs, []), dtype=torch.float32
),
packed_input_ids=torch.tensor(sum(act.seqs, []), dtype=torch.long),
seq_no_eos_mask=torch.tensor(act.no_eos, dtype=torch.bool),
rewards=torch.tensor(rewards, dtype=torch.float32),
version_start=torch.tensor(act.version_start, dtype=torch.int),
version_end=torch.tensor(act.version_end, dtype=torch.int),
prompt_mask=torch.tensor(
sum(
[
[1] * act.prompt_len + [0] * (seqlen - act.prompt_len)
for seqlen in act.seqlens
],
[],
),
dtype=torch.bool,
),
task_ids=prompt.data["task_ids"],
),
)
return [x]
def log_rewards_to_file(
self,
qid: str,
prompt: str,
prompt_len: int,
answers: List[str],
seqlens: List[int],
rewards: List[float],
success: List[bool],
version_starts: List[int],
version_ends: List[int],
):
group_size = len(answers)
for group_idx in range(group_size):
# NOTE: we can ensure that only one process is logging this query id
gen_file_path = os.path.join(
constants.LOG_ROOT,
constants.experiment_name(),
constants.trial_name(),
"generated",
str(version_starts[group_idx]),
f"{qid}.txt",
)
os.makedirs(os.path.dirname(gen_file_path), exist_ok=True)
version_start = version_starts[group_idx]
version_end = version_ends[group_idx]
reward = rewards[group_idx]
answer = answers[group_idx]
seqlen = seqlens[group_idx]
with open(gen_file_path, "a") as _f:
info = "\n".join(
[
f"idx: {group_idx + 1} / {group_size}, seqlen: {seqlen}, "
f"head version: {version_start}, tail version: {version_end}.",
f"reward is {reward}, prompt is {colorama.Fore.YELLOW + colorama.Style.DIM}{prompt}{colorama.Style.RESET_ALL}",
f"sequence is: {colorama.Fore.YELLOW + colorama.Style.DIM}{answer}{colorama.Style.RESET_ALL}.",
]
)
_f.write(info + "\n")
train_pass_monitor_file_path = os.path.join(
constants.LOG_ROOT,
constants.experiment_name(),
constants.trial_name(),
"training_monitor",
str(version_starts[group_idx]),
f"{qid}.jsonl",
)
os.makedirs(os.path.dirname(train_pass_monitor_file_path), exist_ok=True)
with open(train_pass_monitor_file_path, "a") as monitor_file:
monitor_file.write(
json.dumps(
{
"version_start": int(version_start),
"version_end": int(version_end),
"success": bool(success),
"prompt_len": prompt_len,
"answer_len": seqlen - prompt_len,
},
ensure_ascii=False,
)
+ "\n"
)
register_agent("math-single-step", MathSingleStepAgent)

View File

@ -0,0 +1,57 @@
# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0 (the "License").
# A null agent used for testing
import asyncio
import copy
import random
from typing import List
from realhf.api.core.agent_api import Agent, register_agent
from realhf.api.core.data_api import SequenceSample
from realhf.api.core.env_api import EnvironmentService
from realhf.api.core.model_api import (
BundledGenerationOutputs,
GenerationHyperparameters,
)
from realhf.base import logging, testing
logger = logging.getLogger("Null Agent")
class NullAgent(Agent):
OBS_PUT_CNT = 0
ACT_GET_CNT = 0
def __init__(self, episode_length: int = 1, traj_size: int = 1):
self.episode_len = episode_length
self.traj_size = traj_size
async def collect_trajectory(
self,
prompt: SequenceSample,
env: EnvironmentService,
obs_queue: asyncio.Queue,
act_queue: asyncio.Queue,
) -> List[SequenceSample]:
qid = prompt.ids[0]
prompt_token_ids = [
random.randint(0, testing.TESTING_MODEL_VOCAB_SIZE - 1)
for _ in range(random.randint(0, 64))
]
for step in range(self.episode_len):
await obs_queue.put((qid, prompt_token_ids, GenerationHyperparameters()))
self.OBS_PUT_CNT += 1
act = await act_queue.get()
self.ACT_GET_CNT += 1
assert isinstance(act, BundledGenerationOutputs)
ids = [str(qid) + f"-{idx}" for idx in range(self.traj_size)]
traj = [copy.deepcopy(prompt) for _ in range(self.traj_size)]
for t, i in zip(traj, ids):
t.ids[0] = i
return traj
register_agent("null", NullAgent)

View File

@ -3,6 +3,7 @@
# Licensed under the Apache License, Version 2.0 (the "License").
import json
import sys
import traceback
from collections import defaultdict
from typing import Callable, Dict, Hashable, List, Optional
@ -17,7 +18,7 @@ logger = logging.getLogger("Math Code Dataset")
def check_math_metadata_entries(data):
assert data["task"] == "math"
assert data["task"] == "math" or data["task"] == "stem"
assert "query_id" in data
data["query_id"] = str(data["query_id"])
assert isinstance(data["prompt"], str)
@ -34,6 +35,10 @@ def check_code_metadata_entries(data):
if "problem_id" not in data:
data["problem_id"] = data["query_id"]
assert isinstance(data["prompt"], str)
case_size = sys.getsizeof(data["input_output"])
assert (
case_size < 500 * 1024
), f"'input_output' exceeds 500KB ({case_size} bytes). Use remote testcase instead."
input_output = json.loads(data["input_output"])
assert len(input_output["inputs"]) == len(input_output["outputs"])
for inp, out in zip(input_output["inputs"], input_output["outputs"]):
@ -61,7 +66,7 @@ def load_metadata(path):
logger.warning(
f'Key "task" not found in the dataset. Use math as default task type.'
)
if d["task"] == "math":
if d["task"] == "math" or d["task"] == "stem":
d = check_math_metadata_entries(d)
elif d["task"] == "code":
d = check_code_metadata_entries(d)
@ -79,7 +84,6 @@ def load_metadata(path):
class MATHCodePromptDataset(torch.utils.data.Dataset):
def __init__(
self,
util: data_api.DatasetUtility,

View File

@ -54,7 +54,7 @@ def parse_line(id2info, prompt_str, generated, query_id):
f.write(json.dumps({"answer": generated, "solution": cur_solution}) + "\n")
venv_python = "/sympy/bin/python3"
logger.info(f"math verify working dir: `{os.getcwd()}`")
# logger.info(f"math verify working dir: `{os.getcwd()}`")
pro = subprocess.Popen(
" ".join(
[
@ -131,7 +131,7 @@ def parse_lines_in_parallel(
all_query_indices.append(query_indices)
venv_python = "/sympy/bin/python3"
logger.info(f"math verify working dir: `{os.getcwd()}`")
# logger.info(f"math verify working dir: `{os.getcwd()}`")
procs = []
for tmp_id in tmp_ids:
pro = subprocess.Popen(

View File

@ -0,0 +1 @@
import realhf.impl.environment.math_single_step_env

View File

@ -0,0 +1,38 @@
# Copyright 2025 Ant Group Inc.
import asyncio
import os
from typing import List, Tuple
from functioncall.math.verify import math_verify
from realhf.api.core.env_api import EnvironmentService, register_environment
from realhf.base import logging
from realhf.impl.dataset.math_code_dataset import load_metadata
from realhf.impl.dataset.math_parser import parse_lines_in_parallel
ENABLE_FUNCTION_CALL = True if os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", "") else False
math_verify_call = math_verify if ENABLE_FUNCTION_CALL else parse_lines_in_parallel
logger = logging.getLogger("Math Single Step Environment")
class MathSingleStepEnv(EnvironmentService):
def __init__(self, dataset_path: str):
self.id2info, _ = load_metadata(dataset_path)
async def reset(self, seed=None, options=None):
return None, {}
async def step(self, action: Tuple[str, List[str]]):
qid, answers = action
group_size = len(answers)
format_rewards = await asyncio.to_thread(
math_verify_call,
self.id2info,
answers,
[qid for _ in range(group_size)],
)
return None, format_rewards, True, False, {}
register_environment("math-single-step", MathSingleStepEnv)

View File

@ -357,6 +357,11 @@ class PipeTrainInstrSetForMegatron(PipeTrainInstrSet):
if update_successful:
incr = version_steps - self.engine.lr_scheduler.num_steps
self.engine.lr_scheduler.step(incr)
grad_norm = torch.tensor(
grad_norm, device=constants.current_device(), dtype=torch.float32
)
dist.all_reduce(grad_norm, group=constants.tp_and_pp_group())
grad_norm /= constants.tp_and_pp_world_size()
if constants.data_parallel_rank() == 0 and constants.model_parallel_rank() == 0:
logger.info(
f"Model name {constants.model_name()}, "
@ -366,7 +371,15 @@ class PipeTrainInstrSetForMegatron(PipeTrainInstrSet):
f"Current loss scale: {self.engine.optim.get_loss_scale()}. "
f"Learning rate: {[param_group['lr'] for param_group in self.engine.optim.param_groups]}. "
)
return update_successful, grad_norm, num_zeros_in_grad
stat = dict(
update_successful=float(update_successful),
grad_norm=float(grad_norm) if grad_norm is not None else float("nan"),
loss_scale=float(self.engine.optim.get_loss_scale()),
)
for i, param_group in enumerate(self.engine.optim.param_groups):
stat[f"param_group{i}/lr"] = param_group["lr"]
# NOTE: we only have one optimizer step for each stage, so micro_batch_id can be 0
tensor_buffer.put("stats", 0, stat)
class ReaLMegatronEngine(model_api.PipelinableEngine):
@ -448,7 +461,6 @@ class ReaLMegatronEngine(model_api.PipelinableEngine):
)
no_sync_ctx = self.engine.ddp.no_sync()
no_sync_ctx.__enter__()
stat = collections.defaultdict(int)
for i, mb_input in enumerate(mb_inputs):
if i == len(mb_inputs) - 1:
no_sync_ctx.__exit__(None, None, None)
@ -464,7 +476,7 @@ class ReaLMegatronEngine(model_api.PipelinableEngine):
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
).logits
loss, _stat = loss_fn(model_output, mb_input)
loss = loss_fn(model_output, mb_input)
loss_scale = loss_weight_fn(mb_inputs[i]) / total_loss_weight
if token_normalize_scope == "global":
# Megatron will average gradients across DP ranks.
@ -476,13 +488,9 @@ class ReaLMegatronEngine(model_api.PipelinableEngine):
loss *= loss_scale
with cuda_tmarked("bwd", CUDATimeMarkType.backward):
loss.backward()
for k, v in _stat.items():
stat[k] += v
self.engine.finalize_grads()
self._step(version_steps)
return stat
return self._step(version_steps)
@torch.no_grad()
def forward(
@ -521,10 +529,16 @@ class ReaLMegatronEngine(model_api.PipelinableEngine):
# wrapper for profiler
@cuda_tmarked("opt", CUDATimeMarkType.optim_step)
def _step(self, version_steps):
# omit the number of zeros in grads
update_successful, grad_norm, _ = self.engine.optim.step()
if update_successful:
incr = version_steps - self.engine.lr_scheduler.num_steps
self.engine.lr_scheduler.step(incr)
grad_norm = torch.tensor(
grad_norm, device=constants.current_device(), dtype=torch.float32
)
dist.all_reduce(grad_norm, group=constants.tp_and_pp_group())
grad_norm /= constants.tp_and_pp_world_size()
if constants.data_parallel_rank() == 0 and constants.model_parallel_rank() == 0:
logger.info(
f"Megatron backend update success? {update_successful}. "
@ -532,6 +546,14 @@ class ReaLMegatronEngine(model_api.PipelinableEngine):
f"Current loss scale: {self.engine.optim.get_loss_scale()}. "
f"Learning rate: {[param_group['lr'] for param_group in self.engine.optim.param_groups]}. "
)
stat = dict(
update_successful=float(update_successful),
grad_norm=float(grad_norm) if grad_norm is not None else float("nan"),
loss_scale=float(self.engine.optim.get_loss_scale()),
)
for i, param_group in enumerate(self.engine.optim.param_groups):
stat[f"param_group{i}/lr"] = param_group["lr"]
return stat
@dataclasses.dataclass

View File

@ -1,9 +1,7 @@
# Copyright 2025 Ant Group Inc.
import collections
import dataclasses
import math
from contextlib import contextmanager
import random
from typing import *
import torch
@ -79,6 +77,8 @@ class MockPipeTrainInstrSet(PipeTrainInstrSet):
step_id: int,
):
self.optim.step()
# NOTE: we only have one optimizer step for each stage, so micro_batch_id can be 0
tensor_buffer.put("stats", 0, dict(random_stat=random.random()))
class AdamWithLossScale(torch.optim.Adam):
@ -150,7 +150,6 @@ class MockTrainEngine(model_api.PipelinableEngine):
f"pp_size={constants.pipe_parallel_world_size()}, "
f"#tokens per mbs: {[mb.data['packed_input_ids'].shape[0] for mb in mb_inputs]}"
)
stat = collections.defaultdict(int)
for i, mb_input in enumerate(mb_inputs):
input_lens = torch.tensor(
flat2d(mb_input.seqlens["packed_input_ids"]),
@ -164,15 +163,13 @@ class MockTrainEngine(model_api.PipelinableEngine):
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
).logits
loss, _stat = loss_fn(model_output, mb_input)
loss = loss_fn(model_output, mb_input)
loss_scale = loss_weight_fn(mb_inputs[i]) / total_loss_weight
if token_normalize_scope == "global":
loss_scale *= constants.data_parallel_world_size()
loss *= loss_scale
for k, v in _stat.items():
stat[k] += v
return stat
return dict(random_stat=random.random())
@torch.no_grad()
def forward(

View File

@ -670,10 +670,9 @@ class PipeTrainForwardCommInstrSet:
input_cache: SequenceSample = tensor_buffer.get(
"input_cache", micro_batch_id, remove=True
)
loss, stats = loss_fn(model_output, input_cache)
loss = loss_fn(model_output, input_cache)
loss = loss * tensor_buffer.get("loss_scale", micro_batch_id)
tensor_buffer.put("losses", micro_batch_id, loss)
tensor_buffer.put("stats", micro_batch_id, stats)
def _exec_send_activations(
module: ReaLModel,
@ -1057,13 +1056,14 @@ class PipelineRunner:
pipe_schedule=sched,
)
agg_stats = None
agg_stats = {}
stat = tensor_buffer.get("stats", 0, raise_error=False)
stats = [None for _ in range(constants.pipe_parallel_world_size())]
dist.all_gather_object(stats, stat, group=constants.pipe_parallel_cpu_group())
if constants.is_last_pipe_stage():
stats = []
for mbid in range(n_pp_mbs):
stats.append(tensor_buffer.get("stats", mbid))
agg_stats = dict()
for key in stats[0].keys():
agg_stats[key] = torch.stack([stat[key] for stat in stats]).sum()
agg_stats[key] = sum([stat[key] for stat in stats]) / len(stats)
return agg_stats

View File

@ -4,6 +4,7 @@ import asyncio
import dataclasses
import json
import os
import socket
import sys
import time
import traceback
@ -33,8 +34,11 @@ from realhf.api.core.model_api import (
from realhf.base import (
cluster,
constants,
datapack,
gpu_utils,
logging,
name_resolve,
names,
network,
pkg_version,
seeding,
@ -71,91 +75,86 @@ class SGLangAPIClient(LLMAPIClient):
"stop_token_ids": req.stop_token_ids,
}
payload = {
"input_ids": req.prompt_ids,
"input_ids": req.input_ids,
"sampling_params": sample_params,
"return_logprob": req.return_logprob,
"stream": stream,
}
output = APIGenerateOutput.from_input(req)
assert not stream, "streaming mode not yet implemented"
outputs = [APIGenerateOutput.from_input(req) for _ in range(gconfig.n)]
most_recent_timestamps = [time.perf_counter() for _ in range(gconfig.n)]
output_idx = 0
# The following code is partially adopted from sglang/bench_serving.py
output_ids = []
output_logprobs = []
finish_reason = {}
ttft = 0.0
latency = float("inf")
st = time.perf_counter()
most_recent_timestamp = st
timeout = aiohttp.ClientTimeout(total=None, connect=30, sock_read=None)
try:
async with self.session.post(
url=self.generate_url,
json=payload,
timeout=timeout,
) as response:
if response.status == 200:
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
async with self.session.post(url=self.generate_url, json=payload) as response:
response.raise_for_status()
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
latency = time.perf_counter() - st
if chunk == "[DONE]":
pass
else:
data = json.loads(chunk)
# NOTE: Some completion API might have a last
# usage summary response without a token so we
# want to check a token was generated
if data[SGLANG_TOKEN_OUTPUT_IDENTIFIER]:
timestamp = time.perf_counter()
# First token
if ttft == 0.0:
ttft = time.perf_counter() - st
output.ttft = ttft
# Decoding phase
else:
output.itl.append(timestamp - most_recent_timestamp)
most_recent_timestamp = timestamp
output_ids = data[SGLANG_TOKEN_OUTPUT_IDENTIFIER]
finish_reason = data["meta_info"]["finish_reason"]
output_logprobs = data["meta_info"][
"output_token_logprobs"
]
assert finish_reason["type"] in ["length", "stop"], finish_reason
output.output_logprobs = [x[0] for x in output_logprobs]
output.output_ids = output_ids
output.no_eos = finish_reason["type"] == "length"
output.success = True
output.latency = latency
chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
latency = time.perf_counter() - st
if chunk == "[DONE]":
pass
else:
output.error = response.reason or ""
output.success = False
except Exception as e:
output.success = False
exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info))
raise RuntimeError(
f"SGLang generation request fails:\n{output.error}"
) from e
datas = json.loads(chunk)
if not isinstance(datas, list):
datas = [datas]
for data in datas:
return output
output = outputs[output_idx]
timestamp = time.perf_counter()
# First token
if output.ttft == float("inf"):
ttft = time.perf_counter() - st
output.ttft = ttft
# Decoding phase
else:
output.itl.append(
timestamp - most_recent_timestamps[output_idx]
)
async def async_update_weights_from_disk(self, path):
timeout = aiohttp.ClientTimeout(total=300, connect=30, sock_read=None)
async with self.session.post(
url=self.update_weights_url,
json=dict(model_path=path),
timeout=timeout,
) as response:
if response.status != 200:
raise RuntimeError("Update weights failed.")
most_recent_timestamps[output_idx] = timestamp
output.output_ids = [data[SGLANG_TOKEN_OUTPUT_IDENTIFIER]]
finish_reason = data["meta_info"]["finish_reason"]
if req.return_logprob:
output.output_logprobs = [
[
x[0]
for x in data["meta_info"]["output_token_logprobs"]
]
]
assert finish_reason["type"] in [
"length",
"stop",
], finish_reason
output.no_eos = [finish_reason["type"] == "length"]
output.latency = latency
output_idx += 1
return APIGenerateOutput.concat(outputs)
async def async_update_weights_from_disk(self, path, retries=5):
for _ in range(retries):
async with self.session.post(
url=self.update_weights_url,
json=dict(model_path=path),
) as resp:
if resp.status == 200:
res = await resp.json()
success = res["success"]
if success:
return
logger.warning(
f"Update weights failed: {res['message']}. Retrying."
)
logger.warning(f"Update weights failed: {resp.reason}. Retrying.")
            await asyncio.sleep(0.1)
raise RuntimeError("Update weights failed.")
def sglang_server_process(server_args_dict):
@ -163,6 +162,9 @@ def sglang_server_process(server_args_dict):
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import kill_process_tree
if pkg_version.is_version_less("sglang", "0.4.4"):
server_args_dict.pop("log_requests_level")
if pkg_version.is_version_less("sglang", "0.4.3"):
from sglang.srt.server import launch_server
@ -182,6 +184,7 @@ def sglang_server_process(server_args_dict):
)
try:
logger.info(f"SGLang Server Args: {server_args}")
launch_server(server_args)
finally:
kill_process_tree(os.getpid(), include_parent=False)
@ -216,7 +219,24 @@ class SGLangGenerationEngine(PipelinableEngine):
"update_weights_from_disk": f"{self.base_url}/update_weights_from_disk",
}
asyncio.run(self.wait_server())
self.wait_server()
if server_args_dict["enable_metrics"]:
dp_rank = constants.data_parallel_rank()
pp_rank = constants.pipe_parallel_rank()
mp_rank = constants.model_parallel_rank()
metric_server_name = f"d{dp_rank}p{pp_rank}m{mp_rank}"
key = names.metric_server(
constants.experiment_name(),
constants.trial_name(),
"sglang",
metric_server_name,
)
host_ip = server_args_dict["host"]
host_port = server_args_dict["port"]
address = f"{host_ip}:{host_port}"
name_resolve.add(key, address, keepalive_ttl=None, delete_on_exit=True)
logger.info(f"SGLang {metric_server_name} metrics URL: {address}")
self.request_timeout = request_timeout
@ -241,14 +261,14 @@ class SGLangGenerationEngine(PipelinableEngine):
def eval(self):
return self
async def wait_server(self):
def wait_server(self):
# Wait until the server is launched
from sglang.srt.utils import kill_process_tree
from sglang.utils import get_exception_traceback
success = False
for _ in range(SGLANG_INIT_TIMEOUT):
await asyncio.sleep(1)
time.sleep(1)
try:
res = requests.get(
self.base_url + "/get_model_info", timeout=5, headers={}
@ -283,7 +303,6 @@ class SGLangGenerationEngine(PipelinableEngine):
update_weights_url=self.api_urls["update_weights_from_disk"],
) as client:
tasks = []
input_queries = []
for d in input_.unpack():
if len(d.seqlens["packed_input_ids"]) > 1:
raise RuntimeError(
@ -293,35 +312,32 @@ class SGLangGenerationEngine(PipelinableEngine):
prompt_token_ids = d.data["packed_input_ids"].cpu().numpy().tolist()
qid = d.ids[0]
for group_idx in range(gconfig.n):
req = APIGenerateInput(
qid=qid,
group_idx=group_idx,
prompt_ids=prompt_token_ids,
input_ids=prompt_token_ids,
gconfig=gconfig.new(n=1),
stop_token_ids=[tokenizer.pad_token_id, tokenizer.eos_token_id],
return_logprob=True,
)
input_queries.append((qid, group_idx))
tasks.append(
client.async_add_generate_request(
req,
stream=stream,
)
req = APIGenerateInput(
qid=qid,
prompt_ids=prompt_token_ids,
input_ids=prompt_token_ids,
gconfig=gconfig,
stop_token_ids=[tokenizer.pad_token_id, tokenizer.eos_token_id],
return_logprob=True,
)
tasks.append(
client.async_add_generate_request(
req,
stream=stream,
)
)
outputs = {}
for r in asyncio.as_completed(tasks):
out = await r
outputs[(out.qid, out.group_idx)] = out
outputs[out.qid] = out
if pbar:
pbar.update(1)
if pbar is not None:
pbar.close()
results: List[APIGenerateOutput] = [outputs[key] for key in input_queries]
results: List[APIGenerateOutput] = [outputs[key] for key in input_.ids]
# Build the output: generated token ids, generated token scores,
# and logits mask (which will always be None in sglang).
@ -329,9 +345,9 @@ class SGLangGenerationEngine(PipelinableEngine):
batch_logprobs = []
max_seqlen = -1
for x in results:
max_seqlen = max(max_seqlen, len(x.output_ids))
batch_token_ids.append(x.output_ids)
batch_logprobs.append(x.output_logprobs)
max_seqlen = max(max_seqlen, max(x.output_lens))
batch_token_ids += x.output_ids
batch_logprobs += x.output_logprobs
# To be consistent with our internal implementation,
# we should pad generated tokens and logprobs
@ -415,16 +431,20 @@ class SGLangGenerationBackend(ModelBackend, SGLangConfig):
# For simplicity, we let all DP ranks have different ports.
ports = [None for _ in range(constants.data_parallel_world_size())]
while any(port is None for port in ports) or len(set(ports)) != len(ports):
while any(port is None for port in ports) or len(
set(datapack.flat2d(ports))
) != len(datapack.flat2d(ports)):
dist.all_gather_object(
ports,
network.find_free_port(low=20000, high=40000),
network.find_multiple_free_ports(2, low=20000, high=40000),
group=constants.data_parallel_group(),
)
additional_args["port"] = ports[constants.data_parallel_rank()]
api_server_port, dist_port = ports[constants.data_parallel_rank()]
additional_args["port"] = api_server_port
host_ip = socket.gethostbyname(socket.gethostname())
server_args_dict = dict(
host="localhost",
host="localhost" if not self.enable_metrics else host_ip,
# Model and tokenizer
tokenizer_path=self.model_path,
tokenizer_mode="auto",
@ -449,7 +469,7 @@ class SGLangGenerationBackend(ModelBackend, SGLangConfig):
# Expert parallelism
ep_size=1, # TODO: check
# Multi-node distributed serving
dist_init_addr=None,
dist_init_addr=f"{network.gethostip()}:{dist_port}",
nnodes=1,
node_rank=0,
**additional_args,

View File

@ -62,7 +62,7 @@ def extract_python_code(text, min_length=20, strict_syntax=True):
valid_blocks.append(clean_block)
if not valid_blocks:
logger.warning(f"failed to extract python code from {text}")
# logger.warning(f"failed to extract python code from {text}")
return None
# return the last code block
return valid_blocks[-1]
@ -128,7 +128,7 @@ def dispatch_reward_calculation(task, answers, query_id_strs) -> List:
global id2info
assert len(answers) == len(query_id_strs)
format_rewards = []
if task == "math":
if task == "math" or task == "stem":
format_rewards = math_verify_call(id2info, answers, query_id_strs)
elif task == "code":
codes = [extract_python_code(_answer) for _answer in answers]

View File

@ -2,21 +2,21 @@
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import collections
import dataclasses
import functools
import itertools
import time
from typing import Dict, Literal, Optional, Tuple
from typing import Dict, Literal, Optional
import torch
import torch.distributed as dist
import realhf.api.core.model_api as model_api
import realhf.base.constants as constants
import realhf.base.logging as logging
import realhf.impl.model.utils.ppo_functional as ppo_functional
from realhf.api.core.data_api import MicroBatchSpec, SequenceSample
from realhf.api.core.data_api import (
RL_TASKS,
MicroBatchSpec,
SequenceSample,
SequenceSplitSpec,
)
from realhf.base import constants, logging, stats_tracker
from realhf.base.datapack import flat2d
from realhf.impl.dataset.math_parser import parse_lines_in_parallel
from realhf.impl.model.nn.real_llm_api import ReaLModel
@ -53,10 +53,11 @@ def _ppo_actor_loss_from_model_outputs(
input_: SequenceSample,
kl_adapter: ppo_functional.KLController, # const
eps_clip: float, # const
c_clip: float | None,
early_stop_imp_ratio: Optional[float], # const
early_stop_kl: Optional[float], # const
temperature: Optional[float] = 1,
) -> Tuple[torch.FloatTensor, Dict]:
) -> torch.Tensor:
"""Loss function for ppo actor step, all inputs should be splitted into
pipeline micro batches, returns loss and logging stats."""
packed_input_ids = input_.data["packed_input_ids"]
@ -75,7 +76,6 @@ def _ppo_actor_loss_from_model_outputs(
if temperature is not None:
logits /= temperature
n_tokens = ppo_loss_mask.count_nonzero()
logprobs = gather_packed_shifted_log_probs(
logits, cu_seqlens, packed_input_ids
).float()
@ -85,26 +85,72 @@ def _ppo_actor_loss_from_model_outputs(
advantages=advantages,
eps_clip=eps_clip,
loss_mask=ppo_loss_mask,
c_clip=c_clip,
)
importance_weight = ppo_stat["importance_weight"].float() * n_tokens
clip_ratio = ppo_stat["clip_ratio"].float() * n_tokens
approx_kl = ppo_stat["approx_kl"].float() * n_tokens
# Log training statistics
stats_tracker.denominator(
n_tokens=torch.ones(logits.shape[0], dtype=torch.bool, device=logits.device),
n_valid_tokens=ppo_loss_mask.bool(),
clipped_tokens=ppo_stat["clip_mask"],
dual_clipped_tokens=ppo_stat["dual_clip_mask"],
)
stats_tracker.stat(
importance_weight=ppo_stat["importance_weight"],
approx_kl=ppo_stat["approx_kl"],
new_logp=logprobs.detach(),
old_logp=old_logp,
actor_loss=ppo_stat["loss"],
clip_ratio=ppo_stat["clip_mask"].float(),
dual_clip_ratio=ppo_stat["dual_clip_mask"].float(),
denominator="n_valid_tokens",
)
vocab_min_logits = logits.detach().min(-1).values.float()
vocab_max_logits = logits.detach().max(-1).values.float()
dist.all_reduce(
vocab_min_logits, group=constants.model_parallel_group(), op=dist.ReduceOp.MIN
)
dist.all_reduce(
vocab_max_logits, group=constants.model_parallel_group(), op=dist.ReduceOp.MAX
)
stats_tracker.stat(
vocab_min_logits=vocab_min_logits,
vocab_max_logits=vocab_max_logits,
denominator="n_tokens",
)
clip_mask = ppo_stat["clip_mask"]
dual_clip_mask = ppo_stat["dual_clip_mask"]
clipped_new_logp = torch.where(clip_mask, logprobs.detach(), 0.0)
dual_clipped_new_logp = torch.where(dual_clip_mask, logprobs.detach(), 0.0)
clipped_old_logp = torch.where(clip_mask, old_logp, 0.0)
dual_clipped_old_logp = torch.where(dual_clip_mask, old_logp, 0.0)
stats_tracker.stat(
clipped_new_logp=clipped_new_logp,
clipped_old_logp=clipped_old_logp,
denominator="clipped_tokens",
)
stats_tracker.stat(
dual_clipped_new_logp=dual_clipped_new_logp,
dual_clipped_old_logp=dual_clipped_old_logp,
denominator="dual_clipped_tokens",
)
# Logging and early stopping according to KL (logp vs ref) or importance ratio (new logp vs old logp).
mean_ref_kl = (kl_rewards.detach().float() * ppo_loss_mask).sum()
logging_loss = torch.where(ppo_loss_mask, loss.detach().float(), 0.0).sum()
dist.all_reduce(n_tokens, group=constants.data_parallel_group())
dist.all_reduce(mean_ref_kl, group=constants.data_parallel_group())
dist.all_reduce(importance_weight, group=constants.data_parallel_group())
dist.all_reduce(clip_ratio, group=constants.data_parallel_group())
dist.all_reduce(approx_kl, group=constants.data_parallel_group())
dist.all_reduce(logging_loss, group=constants.data_parallel_group())
_imp = (ppo_stat["importance_weight"].float() * ppo_loss_mask).sum()
dist.all_reduce(_imp, group=constants.data_parallel_group())
_kl = (ppo_stat["approx_kl"].float() * ppo_loss_mask).sum()
dist.all_reduce(_kl, group=constants.data_parallel_group())
_n_valid_tokens = ppo_loss_mask.count_nonzero().clone()
dist.all_reduce(_n_valid_tokens, group=constants.data_parallel_group())
mean_ref_kl /= _n_valid_tokens
_imp /= _n_valid_tokens
_kl /= _n_valid_tokens
# Early stopping.
kl_adapter.update(mean_ref_kl / n_tokens, n_steps=cu_seqlens.shape[0] - 1)
_imp = importance_weight / n_tokens
_kl = approx_kl / n_tokens
kl_adapter.update(mean_ref_kl, n_steps=cu_seqlens.shape[0] - 1)
if early_stop_imp_ratio is not None and _imp > early_stop_imp_ratio:
logger.warning(
f"Current importance ratio {_imp.item():.4f} is larger "
@ -118,14 +164,7 @@ def _ppo_actor_loss_from_model_outputs(
)
loss = loss * 0.0
stats = dict(
ppo_approx_kl=approx_kl,
actor_loss=logging_loss,
actor_clip_ratio=clip_ratio,
importance_weight=importance_weight,
)
return loss, stats
return loss
def splited_sum_bool_tensor(t: torch.BoolTensor, chunk_size=256 * 1024 * 1024) -> int:
@ -158,6 +197,7 @@ class PPOActorInterface(model_api.ModelInterface):
gae_lambda: float = 1.0
eps_clip: float = 0.2
c_clip: Optional[float] = None
value_eps_clip: float = 0.2
max_reward_clip: float = 5.0
@ -188,6 +228,8 @@ class PPOActorInterface(model_api.ModelInterface):
reward_delta: bool = True
token_normalize_scope: Literal["global", "dp"] = "global"
sample_reuse: int = 1
def __post_init__(self):
if self.adaptive_kl_ctl:
assert self.adaptive_kl_target is not None
@ -468,14 +510,19 @@ class PPOActorInterface(model_api.ModelInterface):
# We call module.eval() because dropout causes the computation of incorrect log probs.
module.eval()
old_logp: torch.FloatTensor = input_.data["packed_logprobs"].float()
ref_logp: torch.FloatTensor = input_.data["packed_ref_logprobs"].float()
prompt_mask = input_.data["prompt_mask"]
input_lens = torch.tensor(
flat2d(input_.seqlens["packed_input_ids"]), device=model.device
)
cu_seqlens = torch.nn.functional.pad(input_lens.cumsum(0), (1, 0)).int()
prompt_lens = []
for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
prompt_lens.append(prompt_mask[s:e].sum())
prompt_lens = torch.tensor(prompt_lens, device=model.device)
reward_score = input_.data["rewards"].float()
task_ids = input_.data["task_ids"]
task_ids = task_ids.repeat(self.group_size, 1).transpose(0, 1).reshape(-1)
if "dense_rewards" in input_.data:
dense_reward_score = input_.data["dense_rewards"].float()
if not self.disable_value:
@ -485,6 +532,13 @@ class PPOActorInterface(model_api.ModelInterface):
input_.data["packed_input_ids"], dtype=torch.float32
)
seq_no_eos_mask = input_.data["seq_no_eos_mask"]
if self.kl_adapter.value == 0:
ref_logp: torch.FloatTensor = reward_score.new_zeros(
int(input_lens.sum()) - len(input_lens)
)
else:
ref_logp: torch.FloatTensor = input_.data["packed_ref_logprobs"].float()
old_logp: torch.FloatTensor = input_.data["packed_logprobs"].float()
if not self.disable_value:
if self.value_norm:
@ -602,7 +656,7 @@ class PPOActorInterface(model_api.ModelInterface):
advantages = torch.cat(adv_list, 0)
# Prepare data to be split into mini-batches.
input_ = SequenceSample.from_default(
flat_input = SequenceSample.from_default(
ids=list(range(input_.bs * self.group_size)),
data=dict(
advantages=advantages,
@ -613,108 +667,112 @@ class PPOActorInterface(model_api.ModelInterface):
),
seqlens=[int(x) for x in input_lens.cpu().numpy().tolist()],
)
# NOTE: We cannot randomly shuffle data here because
# data must have the same shape across different pipeline stages.
datas, *_ = input_.split(MicroBatchSpec(n_mbs=self.n_minibatches))
logger.info(
f"PPO minibatch split (size {self.n_minibatches}): "
f"#seqs: {[s.bs for s in datas]}, "
f"#tokens: {[sum([sum(lens) for lens in s.seqlens[s._get_split_key()]]) for s in datas]}"
)
if self.use_dense_reward:
dense_reward_score = dense_reward_score[shift_one_indices]
### Logging code starts. ###
_n_seqs = torch.tensor(
[reward_score.shape[0]], dtype=torch.float32, device=model.device
)
no_eos_ratios = seq_no_eos_mask.sum()
# _n_tokens = loss_mask.count_nonzero()
_n_tokens = prompt_mask.logical_not().count_nonzero()
_n_valid_tokens = loss_mask.count_nonzero()
task_reward = reward_score.sum()
if self.use_dense_reward:
dense_reward = (dense_reward_score * loss_mask).sum()
final_reward = (rewards * loss_mask).sum()
_advantages = advantages.sum()
_kl_rewards = (kl_rewards * loss_mask).sum()
prompt_len = prompt_mask.count_nonzero().float()
seq_len = input_lens.float().sum()
with stats_tracker.scope("ppo_actor"):
assert (
task_ids.shape == reward_score.shape
), f"task_ids ({task_ids.shape}) and reward_score ({reward_score.shape}) must have the same shape"
dist.all_reduce(_n_seqs, group=constants.data_parallel_group())
dist.all_reduce(no_eos_ratios, group=constants.data_parallel_group())
dist.all_reduce(task_reward, group=constants.data_parallel_group())
if self.use_dense_reward:
dist.all_reduce(dense_reward, group=constants.data_parallel_group())
dist.all_reduce(final_reward, group=constants.data_parallel_group())
dist.all_reduce(_advantages, group=constants.data_parallel_group())
dist.all_reduce(prompt_len, group=constants.data_parallel_group())
dist.all_reduce(seq_len, group=constants.data_parallel_group())
dist.all_reduce(_n_tokens, group=constants.data_parallel_group())
dist.all_reduce(_n_valid_tokens, group=constants.data_parallel_group())
dist.all_reduce(_kl_rewards, group=constants.data_parallel_group())
task_denominators = {
f"{task}_n_seqs": (task_ids == idx).bool()
for idx, task in enumerate(RL_TASKS)
}
global_stats = dict(
task_reward=float(task_reward / _n_seqs),
kl_reward=float(_kl_rewards / _n_tokens),
final_reward=float(final_reward / _n_seqs),
advantage=float(_advantages / _n_tokens),
avg_seq_len=float(seq_len / _n_seqs),
avg_prompt_len=float(prompt_len / _n_seqs),
n_tokens=int(_n_tokens),
n_valid_tokens=int(_n_valid_tokens),
n_seqs=int(_n_seqs),
no_eos_ratio=float(no_eos_ratios / _n_seqs),
disable_value=int(self.disable_value),
mask_no_eos_with_zero=int(self.mask_no_eos_with_zero),
)
if self.use_dense_reward:
dense_reward = (float(dense_reward / _n_seqs),)
stats_tracker.denominator(
n_seqs=torch.ones_like(reward_score, dtype=torch.bool),
n_tokens=torch.ones_like(prompt_mask, dtype=torch.bool),
n_valid_tokens=loss_mask.bool(),
**task_denominators,
)
### Logging code ends. ###
for task in RL_TASKS:
stats_tracker.stat(
**{f"{task}_reward": reward_score}, denominator=f"{task}_n_seqs"
)
# Run mini-batched PPO training!
train_stats = collections.defaultdict(lambda: 0)
stats = dict(
advantages=advantages,
kl_rewards=kl_rewards,
final_reward=rewards,
)
if self.use_dense_reward:
stats["dense_reward"] = dense_reward_score
stats_tracker.stat(**stats, denominator="n_valid_tokens")
for data in datas:
stats = module.train_batch(
input_=data,
mb_spec=mb_spec,
version_steps=model.version.global_step,
loss_fn=functools.partial(
_ppo_actor_loss_from_model_outputs,
seq_stats = dict(
no_eos_ratios=seq_no_eos_mask.float(),
task_reward=reward_score,
prompt_len=prompt_lens.float(),
seq_len=input_lens.float(),
)
if "version_start" in input_.data:
seq_stats["head_offpolicyness"] = (
model.version.global_step - input_.data["version_start"]
).float()
if "version_end" in input_.data:
seq_stats["tail_offpolicyness"] = (
model.version.global_step - input_.data["version_end"]
).float()
stats_tracker.stat(
**seq_stats,
denominator="n_seqs",
)
# Run mini-batched PPO training!
def _loss_fn(logits, input_):
return _ppo_actor_loss_from_model_outputs(
logits,
input_,
kl_adapter=self.kl_adapter,
eps_clip=self.eps_clip,
early_stop_imp_ratio=self.early_stop_imp_ratio,
early_stop_kl=self.early_stop_kl,
c_clip=self.c_clip,
temperature=self.gconfig.temperature,
),
loss_weight_fn=lambda x: x.data["ppo_loss_mask"].count_nonzero(),
token_normalize_scope=self.token_normalize_scope,
)
)
if stats:
for k, v in stats.items():
train_stats[k] += v
cur_epoch = model.version.epoch
for reuse in range(self.sample_reuse):
with stats_tracker.scope(f"reuse{reuse}"):
# NOTE: We split PPO minibatches in terms of #seqs instead of #tokens.
flat_input = SequenceSample.shuffled(flat_input)
bs = flat_input.bs
sizes = [0 for _ in range(self.n_minibatches)]
for idx in range(bs):
sizes[idx % self.n_minibatches] += 1
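# e.g., 10 sequences split into 4 minibatches gives sizes [3, 3, 2, 2].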
spec = SequenceSplitSpec(sizes=sizes)
datas = flat_input.split_with_spec(spec)
logger.info(
f"PPO minibatch split (size {self.n_minibatches}): "
f"#seqs: {[s.bs for s in datas]}, "
f"#tokens: {[sum([sum(lens) for lens in s.seqlens[s._get_split_key()]]) for s in datas]}"
)
for mb_i, data in enumerate(datas):
with stats_tracker.scope(f"mb{mb_i}"):
train_stat = module.train_batch(
input_=data,
mb_spec=mb_spec,
version_steps=model.version.global_step,
loss_fn=_loss_fn,
loss_weight_fn=lambda x: x.data[
"ppo_loss_mask"
].count_nonzero(),
token_normalize_scope=self.token_normalize_scope,
)
stats_tracker.scalar(**train_stat)
stats_tracker.scalar(
disable_value=self.disable_value,
mask_no_eos_with_zero=self.mask_no_eos_with_zero,
c_clip=self.c_clip if self.c_clip is not None else float("nan"),
eps_clip=self.eps_clip,
)
model.inc_version()
# FIXME: It only logs the MoE aux loss of the final PPO mini-batch.
global_stats.update(
constants.log_global_stats_tracker(
return_dict=True, clear_stats_after_logging=True
)
)
if train_stats:
train_stats = dict(
ppo_approx_kl=float(train_stats["ppo_approx_kl"] / _n_tokens),
actor_loss=float(train_stats["actor_loss"] / _n_tokens),
actor_clip_ratio=float(train_stats["actor_clip_ratio"] / _n_tokens),
importance_weight=float(train_stats["importance_weight"] / _n_tokens),
)
train_stats = dict(**train_stats, **global_stats)
return dict(train_stats)
return stats_tracker.export()
# Mock methods for profiling only.
def _mock_inference(
@ -803,7 +861,7 @@ def _ppo_critic_loss_from_model_outputs(
value_eps_clip: float,
kl_adapter: ppo_functional.KLController,
rms=None,
) -> Tuple[torch.FloatTensor, Dict]:
) -> torch.Tensor:
cu_seqlens = (
torch.nn.functional.pad(
@ -846,27 +904,23 @@ def _ppo_critic_loss_from_model_outputs(
denormalized_values = new_values
# Logging.
n_tokens = ppo_loss_mask.count_nonzero()
mean_ref_kl = (kl_rewards.detach().float() * ppo_loss_mask).sum()
logging_loss = loss.detach().float() * n_tokens
clip_ratio = loss_stat["clip_ratio"].float() * n_tokens
denormalized_values = (
torch.where(ppo_loss_mask, denormalized_values, 0.0).sum().detach().float()
stats_tracker.denominator(n_valid_tokens=ppo_loss_mask.bool())
stats_tracker.stat(
value_loss=loss_stat["loss"],
clip_ratio=loss_stat["clip_mask"].float(),
denormalized_values=denormalized_values.detach().float(),
denominator="n_valid_tokens",
)
dist.all_reduce(n_tokens, group=constants.data_parallel_group())
dist.all_reduce(mean_ref_kl, group=constants.data_parallel_group())
dist.all_reduce(logging_loss, group=constants.data_parallel_group())
dist.all_reduce(clip_ratio, group=constants.data_parallel_group())
dist.all_reduce(denormalized_values, group=constants.data_parallel_group())
# Update KL coefficient to be consistent with actor.
mean_ref_kl = (kl_rewards.detach().float() * ppo_loss_mask).sum()
dist.all_reduce(mean_ref_kl, group=constants.data_parallel_group())
_n_valid_tokens = ppo_loss_mask.count_nonzero().clone()
dist.all_reduce(_n_valid_tokens, group=constants.data_parallel_group())
mean_ref_kl /= _n_valid_tokens
kl_adapter.update(mean_ref_kl, n_steps=cu_seqlens.shape[0] - 1)
return loss, dict(
value_loss=logging_loss,
value_clip_ratio=clip_ratio,
denormalized_values=denormalized_values,
)
return loss
@dataclasses.dataclass
@ -896,6 +950,8 @@ class PPOCriticInterface(model_api.ModelInterface):
reward_delta: bool = True
token_normalize_scope: Literal["global", "dp"] = "global"
sample_reuse: int = 1
def __post_init__(self):
if self.adaptive_kl_ctl:
assert self.adaptive_kl_target is not None
@ -988,8 +1044,6 @@ class PPOCriticInterface(model_api.ModelInterface):
# We call module.eval() because dropout causes the computation of incorrect log probs.
module.eval()
old_logp: torch.FloatTensor = input_.data["packed_logprobs"].float()
ref_logp: torch.FloatTensor = input_.data["packed_ref_logprobs"].float()
prompt_mask = input_.data["prompt_mask"]
input_lens = torch.tensor(
flat2d(input_.seqlens["packed_input_ids"]), device=model.device
@ -1000,6 +1054,13 @@ class PPOCriticInterface(model_api.ModelInterface):
dense_reward_score = input_.data["dense_rewards"].float()
values = input_.data["values"].float()
seq_no_eos_mask = input_.data["seq_no_eos_mask"]
if self.kl_adapter.value == 0:
ref_logp: torch.FloatTensor = reward_score.new_zeros(
int(input_lens.sum()) - len(input_lens)
)
else:
ref_logp: torch.FloatTensor = input_.data["packed_ref_logprobs"].float()
old_logp: torch.FloatTensor = input_.data["packed_logprobs"].float()
if self.value_norm:
denormalized_values = self.rms.denormalize(values)
@ -1080,7 +1141,7 @@ class PPOCriticInterface(model_api.ModelInterface):
normalized_returns = returns
# Prepare data to be split into mini-batches.
input_ = SequenceSample.from_default(
flat_input = SequenceSample.from_default(
ids=list(range(input_.bs * self.group_size)),
data=dict(
returns=normalized_returns,
@ -1091,63 +1152,54 @@ class PPOCriticInterface(model_api.ModelInterface):
),
seqlens=[int(x) for x in input_lens.cpu().numpy().tolist()],
)
# NOTE: We cannot randomly shuffle data here because
# data must have the same shape across different pipeline stages.
datas, *_ = input_.split(MicroBatchSpec(n_mbs=self.n_minibatches))
logger.info(
f"PPO minibatch split (size {self.n_minibatches}): "
f"#seqs: {[s.bs for s in datas]}, "
f"#tokens: {[sum([sum(lens) for lens in s.seqlens[s._get_split_key()]]) for s in datas]}"
)
# Logging.
returns = torch.where(loss_mask, returns, 0.0).sum()
n_tokens = loss_mask.count_nonzero()
dist.all_reduce(returns, group=constants.data_parallel_group())
dist.all_reduce(n_tokens, group=constants.data_parallel_group())
global_stats = dict(returns=float(returns / n_tokens), n_tokens=int(n_tokens))
with stats_tracker.scope("ppo_critic"):
stats_tracker.denominator(n_valid_tokens=loss_mask)
stats_tracker.stat(returns=returns, denominator="n_valid_tokens")
# Run mini-batched PPO training!
train_stats = collections.defaultdict(lambda: 0)
for data in datas:
stats = module.train_batch(
input_=data,
mb_spec=mb_spec,
version_steps=model.version.global_step,
loss_fn=functools.partial(
_ppo_critic_loss_from_model_outputs,
def _loss_fn(out, inp):
return _ppo_critic_loss_from_model_outputs(
out,
inp,
value_eps_clip=self.value_eps_clip,
kl_adapter=self.kl_adapter,
rms=None if not self.value_norm else self.rms,
),
loss_weight_fn=lambda x: x.data["ppo_loss_mask"].count_nonzero(),
token_normalize_scope=self.token_normalize_scope,
)
)
if stats:
for k, v in stats.items():
train_stats[k] += v
# Run mini-batched PPO training!
for reuse in range(self.sample_reuse):
with stats_tracker.scope(f"reuse{reuse}"):
# NOTE: We split PPO minibatches in terms of #seqs instead of #tokens.
flat_input = SequenceSample.shuffled(flat_input)
bs = flat_input.bs
sizes = [0 for _ in range(self.n_minibatches)]
for idx in range(bs):
sizes[idx % self.n_minibatches] += 1
spec = SequenceSplitSpec(sizes=sizes)
datas = flat_input.split_with_spec(spec)
logger.info(
f"PPO minibatch split (size {self.n_minibatches}): "
f"#seqs: {[s.bs for s in datas]}, "
f"#tokens: {[sum([sum(lens) for lens in s.seqlens[s._get_split_key()]]) for s in datas]}"
)
for mb_i, data in enumerate(datas):
with stats_tracker.scope(f"mb{mb_i}"):
stats = module.train_batch(
input_=data,
mb_spec=mb_spec,
version_steps=model.version.global_step,
loss_fn=_loss_fn,
loss_weight_fn=lambda x: x.data[
"ppo_loss_mask"
].count_nonzero(),
token_normalize_scope=self.token_normalize_scope,
)
stats_tracker.scalar(**stats)
cur_epoch = model.version.epoch
model.inc_version()
# FIXME: It only logs the MoE aux loss of the final PPO mini-batch.
global_stats.update(
constants.log_global_stats_tracker(
return_dict=True, clear_stats_after_logging=True
)
)
if train_stats:
train_stats = dict(
value_loss=float(train_stats["value_loss"] / n_tokens),
value_clip_ratio=float(train_stats["value_clip_ratio"] / n_tokens),
denormalized_values=float(
train_stats["denormalized_values"] / n_tokens
),
**global_stats,
)
return dict(train_stats)
return stats_tracker.export()
# Mock methods for profiling only.
def _mock_inference(

View File

@ -2,6 +2,7 @@
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import dataclasses
from typing import Dict, Literal
import torch
@ -10,8 +11,8 @@ import torch.utils.data
import tqdm
import realhf.api.core.model_api as model_api
import realhf.base.constants as constants
from realhf.api.core.data_api import MicroBatchSpec, SequenceSample
from realhf.base import constants, stats_tracker
from realhf.base.datapack import flat2d
from realhf.impl.model.nn.real_llm_api import ReaLModel
from realhf.impl.model.utils.functional import (
@ -36,7 +37,7 @@ def compute_packed_sft_loss(
prompt_mask = prompt_mask[shift_one_indices]
logprobs = torch.where(prompt_mask, 0, logprobs)
loss_sum = -logprobs.sum()
loss = -logprobs.sum() / prompt_mask.logical_not().count_nonzero()
with torch.no_grad():
seqlogp = torch.zeros(
@ -49,45 +50,39 @@ def compute_packed_sft_loss(
cu_seqlens,
logprobs.shape,
)
seqlogp[i] = torch.where(m, 0.0, logp).sum() / (
seqlogp[i] = torch.where(m, 0.0, logp.detach()).sum() / (
m.numel() - m.count_nonzero()
)
logging_ppl = (-seqlogp).exp().sum()
token_denorm = prompt_mask.numel() - prompt_mask.count_nonzero()
seq_denorm = torch.tensor(
[cu_seqlens.shape[0] - 1], dtype=torch.float32, device=logits.device
## Logging stats
stats_tracker.denominator(
n_seqs=torch.ones(
cu_seqlens.shape[0] - 1, dtype=torch.bool, device=logprobs.device
),
n_tokens=torch.ones(logits.shape[0], dtype=torch.bool, device=logits.device),
n_valid_tokens=prompt_mask.logical_not(),
prompt_tokens=prompt_mask,
)
# Logging loss and perplexity.
logging_loss = loss_sum.detach().clone()
logging_token_denorm = token_denorm.detach().clone().float()
stats_tracker.stat(ppl=(-seqlogp).exp().float(), denominator="n_seqs")
stats_tracker.stat(loss=-logprobs.detach(), denominator="n_valid_tokens")
vocab_min_logits = logits.detach().min(-1).values.float()
vocab_max_logits = logits.detach().max(-1).values.float()
dist.all_reduce(
logging_ppl, op=dist.ReduceOp.SUM, group=constants.data_parallel_group()
vocab_min_logits, group=constants.model_parallel_group(), op=dist.ReduceOp.MIN
)
dist.all_reduce(
logging_loss,
op=dist.ReduceOp.SUM,
group=constants.data_parallel_group(),
vocab_max_logits, group=constants.model_parallel_group(), op=dist.ReduceOp.MAX
)
dist.all_reduce(
seq_denorm, op=dist.ReduceOp.SUM, group=constants.data_parallel_group()
)
dist.all_reduce(
logging_token_denorm,
op=dist.ReduceOp.SUM,
group=constants.data_parallel_group(),
stats_tracker.stat(
vocab_min_logits=vocab_min_logits,
vocab_max_logits=vocab_max_logits,
denominator="n_tokens",
)
loss = loss_sum / token_denorm
return loss, {
"loss": logging_loss,
"ppl": logging_ppl,
"n_tokens": logging_token_denorm,
"n_seqs": seq_denorm,
}
return loss
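For reference, the per-sequence perplexity logged above is the exponential of the negative mean log probability over a sequence's non-prompt tokens. A tiny self-contained illustration with made-up numbers (not part of the codebase):

import torch

logp = torch.tensor([-0.1, -0.2, -0.3, -0.4])            # per-token log probs of one sequence
prompt_mask = torch.tensor([True, True, False, False])   # the first two tokens are prompt

seq_mean_logp = torch.where(prompt_mask, 0.0, logp).sum() / (~prompt_mask).sum()
ppl = torch.exp(-seq_mean_logp)                           # exp(0.35) ~ 1.42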
@dataclasses.dataclass
class SFTInterface(model_api.ModelInterface):
token_normalize_scope: Literal["global", "dp"] = "global"
@ -98,32 +93,22 @@ class SFTInterface(model_api.ModelInterface):
module.train()
stat = module.train_batch(
input_=data,
loss_fn=compute_packed_sft_loss,
loss_weight_fn=lambda x: x.data["prompt_mask"]
.logical_not()
.count_nonzero(),
token_normalize_scope=self.token_normalize_scope,
mb_spec=mb_spec,
version_steps=model.version.global_step,
)
with stats_tracker.scope("sft"):
stats = module.train_batch(
input_=data,
loss_fn=compute_packed_sft_loss,
loss_weight_fn=lambda x: x.data["prompt_mask"]
.logical_not()
.count_nonzero(),
token_normalize_scope=self.token_normalize_scope,
mb_spec=mb_spec,
version_steps=model.version.global_step,
)
stats_tracker.scalar(**stats)
model.inc_version()
res = dict()
global_stats = constants.log_global_stats_tracker(
return_dict=True, clear_stats_after_logging=True
)
if stat:
res = dict(
loss=float(stat["loss"]) / int(stat["n_tokens"]),
ppl=float(stat["ppl"]) / int(stat["n_seqs"]),
n_tokens=int(stat["n_tokens"]),
n_seqs=int(stat["n_seqs"]),
**global_stats,
)
return res
return stats_tracker.export()
def save(self, model: model_api.Model, save_dir: str):
module = model.module
@ -144,36 +129,18 @@ class SFTInterface(model_api.ModelInterface):
module = model_.module
module.eval()
losses = n_seqs = ppl = n_tokens = 0
for step, x in enumerate(tqdm.tqdm(eval_dataloader)):
x: SequenceSample
res = module.eval_batch(
input_=x.to_device(device),
loss_fn=compute_packed_sft_loss,
mb_spec=MicroBatchSpec(),
)
with stats_tracker.scope("sft-eval"):
module.eval_batch(
input_=x.to_device(device),
loss_fn=compute_packed_sft_loss,
mb_spec=MicroBatchSpec(),
)
if res is not None:
_, stat = res
losses += stat["loss"]
n_tokens += stat["n_tokens"]
n_seqs += stat["n_seqs"]
ppl += stat["ppl"]
global_stats = constants.log_global_stats_tracker(
return_dict=True, clear_stats_after_logging=True
)
if res is not None:
return dict(
loss=float(losses / n_tokens),
ppl=float(ppl / n_seqs),
n_tokens=int(n_tokens),
n_seqs=int(n_seqs),
**global_stats,
)
return dict()
return stats_tracker.export()
model_api.register_interface("sft", SFTInterface)

View File

@ -403,19 +403,10 @@ def update_aux_losses_tracker(
layer_number (int): Layer index of the loss.
num_layers (int): The number of total layers.
"""
from realhf.base.stats_tracker import MOE_AUX_LOSSES
assert name in aux_loss_names, f"Invalid aux loss name: {name}."
losses = constants.get_from_global_stats_tracker(name)
losses = MOE_AUX_LOSSES.get(name, None)
if losses is None:
losses = torch.zeros(num_layers, device=loss.device)
losses[layer_number] += loss.detach()
constants.save_to_global_stats_tracker(
name, losses, hook=avg_aux_loss, stats_key=name
)
def avg_aux_loss(stats_key):
loss: torch.Tensor = constants.get_from_global_stats_tracker(stats_key)
dist.all_reduce(loss, group=constants.pipe_parallel_group())
loss = loss.mean()
dist.all_reduce(loss, op=dist.ReduceOp.SUM, group=constants.data_parallel_group())
constants.save_to_global_stats_tracker(stats_key, float(loss))

View File

@ -54,6 +54,7 @@ def actor_loss_fn(
advantages: torch.FloatTensor,
eps_clip: float,
loss_mask: Optional[torch.BoolTensor] = None,
c_clip: Optional[float] = None,
) -> Tuple[torch.Tensor, Dict]:
"""Compute PPO actor loss function.
@ -65,6 +66,8 @@ def actor_loss_fn(
old_logprobs (torch.FloatTensor): Old log probabilities of actions.
advantages (torch.FloatTensor): GAE (normalized) advantages.
eps_clip (float): Clip ratio of PPO.
c_clip (float | None): The dual clip factor.
Check https://arxiv.org/pdf/1912.09729 for details.
loss_mask (Optional[torch.BoolTensor], optional): Mask for loss computation.
1 if valid else 0. Defaults to None.
@ -85,41 +88,39 @@ def actor_loss_fn(
loss_mask_count = loss_mask.count_nonzero() or 1
# For numerical stability.
ratio = torch.where(loss_mask, torch.exp(logprobs - old_logprobs), 0)
approx_kl = torch.where(loss_mask, (logprobs - old_logprobs).detach(), 0.0)
else:
ratio = torch.exp(logprobs - old_logprobs)
approx_kl = (logprobs - old_logprobs).detach()
clipped_ratio = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip)
pg_loss1 = -advantages * ratio
pg_loss2 = -advantages * clipped_ratio
if loss_mask is not None:
pg_loss = (
torch.where(loss_mask, torch.max(pg_loss1, pg_loss2), 0).sum()
/ loss_mask_count
)
else:
pg_loss = torch.max(pg_loss1, pg_loss2).mean()
clip_mask = pg_loss1.detach() < pg_loss2.detach()
if loss_mask is not None:
proportion_clipped = (
clip_mask.logical_and_(loss_mask).count_nonzero() / loss_mask_count
)
importance_weight = (
torch.where(loss_mask, ratio.detach(), 0).sum() / loss_mask_count
)
approx_kl = approx_kl.sum() / loss_mask_count
pg_loss = torch.max(pg_loss1, pg_loss2)
if c_clip is not None:
assert c_clip > 1.0, c_clip
pg_loss3 = torch.sign(advantages) * c_clip * advantages
dual_clip_mask = pg_loss3.detach() < pg_loss.detach()
pg_loss = torch.min(pg_loss, pg_loss3)
else:
proportion_clipped = clip_mask.count_nonzero()
importance_weight = ratio.detach().mean()
approx_kl = approx_kl.mean()
dual_clip_mask = torch.zeros_like(clip_mask)
logging_loss = pg_loss.detach()
if loss_mask is not None:
pg_loss = torch.where(loss_mask, pg_loss, 0).sum() / loss_mask_count
else:
pg_loss = pg_loss.mean()
if loss_mask is not None:
clip_mask.logical_and_(loss_mask)
dual_clip_mask.logical_and_(loss_mask)
# Keep CUDA tensors here for the all-reduce after the train step.
stat = dict(
clip_ratio=proportion_clipped,
importance_weight=importance_weight,
approx_kl=approx_kl,
loss=logging_loss,
importance_weight=ratio.detach(),
approx_kl=(logprobs - old_logprobs).detach(),
clip_mask=clip_mask,
dual_clip_mask=dual_clip_mask,
)
return pg_loss, stat
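The dual-clip term added above only binds for negative advantages: the standard clipped objective can still produce an arbitrarily large per-token loss when the ratio explodes, and c_clip * |A| caps it. A small numeric sketch in plain PyTorch (values chosen purely for illustration):

import torch

adv = torch.tensor([-1.0])                             # negative advantage
ratio = torch.tensor([10.0])                           # exp(new_logp - old_logp) is large
eps_clip, c_clip = 0.2, 3.0

pg1 = -adv * ratio                                     # 10.0
pg2 = -adv * ratio.clamp(1 - eps_clip, 1 + eps_clip)   # 1.2
pg = torch.max(pg1, pg2)                               # 10.0, grows without bound with the ratio
pg3 = torch.sign(adv) * c_clip * adv                   # 3.0 == c_clip * |adv|
pg_dual = torch.min(pg, pg3)                           # 3.0, the per-token loss is capped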
@ -187,17 +188,14 @@ def critic_loss_fn(
with torch.no_grad():
clip_mask = value_loss_clipped.detach() > value_loss_original.detach()
if loss_mask is not None:
mask_count = loss_mask.count_nonzero() or 1
proportion_clipped = (
clip_mask.logical_and_(loss_mask).count_nonzero() / mask_count
)
else:
proportion_clipped = clip_mask.count_nonzero()
clip_mask.logical_and_(loss_mask)
stat = dict(clip_ratio=proportion_clipped)
stat = dict(clip_mask=clip_mask, loss=value_loss.detach())
if loss_mask is not None:
value_loss = torch.where(loss_mask, value_loss, 0).sum() / mask_count
value_loss = (
torch.where(loss_mask, value_loss, 0).sum() / loss_mask.count_nonzero()
)
else:
value_loss = value_loss.mean()
@ -354,6 +352,10 @@ def get_packed_advantages_and_returns(
short1cu_seqlens: torch.IntTensor,
seq_no_eos_mask: torch.FloatTensor,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
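# A device index of -1 means the tensors live on CPU, so fall back to the
# pure-PyTorch GAE implementation instead of the CUDA kernel.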
if rewards.get_device() == -1:
return pygae1d_nolp_misalign(
rewards, values, short1cu_seqlens, seq_no_eos_mask, gamma, lam
)
try:
return cugae1d_nolp_misalign_func(
rewards,

View File

@ -161,7 +161,16 @@ def make(mode, expr_name, trial_name, **kwargs) -> SchedulerClient:
schedule_strategy = kwargs.get("schedule_strategy", "empty_first")
evaluator = kwargs.get("evaluator", None)
return SlurmSchedulerClient(expr_name, trial_name, schedule_strategy, evaluator)
job_group_id = kwargs.get("job_group_id", None)
job_group_index = kwargs.get("job_group_index", None)
return SlurmSchedulerClient(
expr_name,
trial_name,
schedule_strategy,
evaluator,
job_group_id,
job_group_index,
)
elif mode == "local":
from realhf.scheduler.local.client import LocalSchedulerClient

View File

@ -38,6 +38,8 @@ class SlurmSchedulerClient(SchedulerClient):
trial_name: str,
schedule_strategy: str,
evaluator: Optional[AutomaticEvaluator],
job_group_id: str,
job_group_index: int,
):
super().__init__(expr_name, trial_name)
@ -49,6 +51,8 @@ class SlurmSchedulerClient(SchedulerClient):
self.__submission_counter = defaultdict(int)
self.__wprocs_counter = defaultdict(int)
self.__evaluator = evaluator
self.__job_group_id = job_group_id
self.__job_group_index = job_group_index
def submit(self, worker_type, cmd, **kwargs):
self.submit_array(worker_type, cmd, count=1, **kwargs)
@ -98,6 +102,8 @@ class SlurmSchedulerClient(SchedulerClient):
begin=begin,
deadline=deadline,
time_limit=time_limit,
job_group_id=self.__job_group_id,
job_group_index=self.__job_group_index,
)
if (

View File

@ -10,6 +10,7 @@ import collections
import dataclasses
import datetime
import getpass
import json
import math
import os
import shutil
@ -21,6 +22,7 @@ import pandas as pd
import realhf.base.cluster as cluster
import realhf.base.logging as logging
import realhf.version as version
from realhf.base.constants import LOG_ROOT
from realhf.scheduler.client import JobException, JobInfo, JobState
@ -224,6 +226,8 @@ class SlurmLaunchInfo:
worker_type: str
worker_submission_idx: int
wprocs_in_job: int
job_group_id: str
job_group_index: str
resource_requirement: SlurmResource
cmd: str
@ -264,14 +268,8 @@ class SlurmLaunchInfo:
f"GPU per worker {gpu_per_worker}, workers per jobstep (process size in `apps.remote`) {self.wprocs_per_jobstep}, "
f"number of jobsteps (instance of running `apps.remote`) {self.n_jobsteps}"
)
elif gpu_per_worker == 0:
self.wprocs_per_jobstep = self.wprocs_in_job
self.n_jobsteps = 1
elif gpu_per_worker == 1:
self.n_jobsteps = self.wprocs_in_job
self.wprocs_per_jobstep = 1
else:
self.n_jobsteps = 1
self.n_jobsteps = self.wprocs_in_job
self.wprocs_per_jobstep = 1
@property
@ -399,6 +397,18 @@ class SlurmLaunchInfo:
else:
gres_line = f"--gres=gpu:{cluster.spec.n_gpus_per_node}"
srun_env = os.environ.copy()
job_metadata = {
"user": srun_env.get("EMAILPREFIX", ""),
"version": version.__version__,
"branch": version.__branch__,
"commit": version.__commit__,
"dirty": version.__is_dirty__,
"job_group_id": self.job_group_id,
"job_group_index": self.job_group_index,
}
job_metadata_json = json.dumps(job_metadata)
lines = [
"#!/bin/bash",
f"#SBATCH --job-name={self.slurm_name}",
@ -414,9 +424,9 @@ class SlurmLaunchInfo:
f"#SBATCH --time={self.time_limit}" if self.time_limit else "",
f"#SBATCH --begin={self.begin}" if self.begin else "",
f"#SBATCH --deadline={self.deadline}" if self.deadline else "",
f"#SBATCH --comment='{job_metadata_json}'",
]
srun_env = os.environ.copy()
if self.hostfile:
srun_env["SLURM_HOSTFILE"] = self.hostfile_path
# Setup step command.
@ -789,12 +799,16 @@ def allocate_resources(
# (16 PPUs/8 GPUs by default)
batched_requirement = info.resource_requirement
batched_ntasks = 1
if info.resource_requirement.gpu > 0:
assert task_left % cluster.spec.n_gpus_per_node == 0
batched_ntasks = cluster.spec.n_gpus_per_node
batched_requirement = (
cluster.spec.n_gpus_per_node * info.resource_requirement
)
gpu_per_task = info.resource_requirement.gpu
if gpu_per_task > 0:
assert (
task_left * gpu_per_task % cluster.spec.n_gpus_per_node == 0
), (task_left, gpu_per_task)
assert (
cluster.spec.n_gpus_per_node % gpu_per_task == 0
), gpu_per_task
batched_ntasks = int(cluster.spec.n_gpus_per_node // gpu_per_task)
batched_requirement = batched_ntasks * info.resource_requirement
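# Worked example (hypothetical numbers): with 8 GPUs per node and
# gpu_per_task == 2, four jobsteps are batched together, so four times the
# per-task requirement is subtracted from the node's resources below.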
try:
resource = resource - batched_requirement
except InvalidGPUTypeException:

View File

@ -2,19 +2,25 @@
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import asyncio
import importlib
import os
import traceback
from typing import Type
import realhf.api.core.system_api
import realhf.base.logging as logging
logger = logging.getLogger("system")
# NOTE: Workers are configured in the following order.
# Take special care when adding a new worker type.
WORKER_TYPES = ["model_worker", "master_worker"]
WORKER_TYPES = [
"generation_server",
"gserver_manager",
"model_worker",
"master_worker",
"rollout_worker",
]
def load_worker(worker_type: str) -> Type:
@ -55,7 +61,10 @@ def run_worker(
)
worker = worker_class(server=server)
try:
worker.run()
if worker_type in ["rollout_worker"]:
asyncio.run(worker.run_async())
else:
worker.run()
except Exception as e:
logger.error("Worker %s failed with exception: %s", worker_name, e)
logger.error(traceback.format_exc())

View File

@ -10,7 +10,9 @@ import getpass
import json
import os
import re
import socket
import sys
import threading
import time
import traceback
from dataclasses import asdict
@ -24,11 +26,48 @@ import torch
from omegaconf import OmegaConf
import realhf.api.core.system_api as system_api
from realhf.base import constants, gpu_utils, logging, name_resolve, names
from realhf.base import constants, gpu_utils, logging, name_resolve, names, pkg_version
from realhf.base.cluster import spec as cluster_spec
from realhf.system import WORKER_TYPES, load_worker, worker_base, worker_control
from realhf.system.worker_base import WorkerServerStatus as Wss
flask_available = False
if pkg_version.is_available("flask"):
from flask import Flask, jsonify
app = Flask(__name__)
@app.route("/discovery", methods=["GET"])
def discovery():
key = names.metric_server_root(
constants.experiment_name(), constants.trial_name()
)
addresses = name_resolve.get_subtree(key)
result = []
if len(addresses) > 0:
result.append(
{
"targets": addresses,
"labels": {
"experiment": constants.experiment_name(),
"trial": constants.trial_name(),
},
}
)
logger.info(f"Discover metric servers: {result}")
return jsonify(result)
def start_metric_discovery_server(port: int):
host_ip = socket.gethostbyname(socket.gethostname())
logger.info(f"Start metric discovery server: http://{host_ip}:{port}/discovery")
app.run(debug=False, use_reloader=False, host="0.0.0.0", port=port)
flask_available = True
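The /discovery route above returns the list-of-target-groups shape used by Prometheus-style HTTP service discovery. With illustrative addresses and labels (not real values), the response body would look roughly like:

[
    {
        "targets": ["10.0.0.12:20451", "10.0.0.13:20451"],
        "labels": {"experiment": "my-ppo-exp", "trial": "trial-0"}
    }
]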
CONNECTION_RETRY_AFTER_SECONDS = 360
logger = logging.getLogger("controller", "colored")
@ -91,7 +130,9 @@ class Controller:
):
# Scheduling and connecting to workers.
workers_configs = [
(k, getattr(setup, k), getattr(scheduling, k)) for k in WORKER_TYPES
(k, getattr(setup, k), getattr(scheduling, k))
for k in WORKER_TYPES
if len(getattr(setup, k)) > 0
]
# Sanity check for scheduling and configuration.
@ -123,6 +164,13 @@ class Controller:
logger.info(f"Configuration has {len(config)} {name}.")
def start(self, experiment: system_api.Experiment, ignore_worker_error=False):
if flask_available and experiment.metric_discovery_port > 0:
server_thread = threading.Thread(
target=start_metric_discovery_server,
args=(experiment.metric_discovery_port,),
)
server_thread.start()
if ignore_worker_error:
check_worker_status = ()
remove_worker_status = (
@ -146,7 +194,11 @@ class Controller:
for i, setup in enumerate(setups):
self.__check_consistent_scheduling(scheduling, setup, verbose=(i == 0))
worker_counts = [(k, len(getattr(setups[0], k))) for k in WORKER_TYPES]
worker_counts = [
(k, len(getattr(setups[0], k)))
for k in WORKER_TYPES
if len(getattr(setups[0], k)) > 0
]
name_resolve.add(
names.trial_registry(self.experiment_name, self.trial_name),
@ -229,6 +281,8 @@ class Controller:
)
try:
for name in WORKER_TYPES:
if len(getattr(setup, name)) == 0:
continue
worker_infos = [x.worker_info for x in getattr(setup, name)]
logger.info(f"Configuring Workers: {name}...")
@ -610,7 +664,9 @@ class RayController:
if not isinstance(setup, list):
setup = [setup]
worker_counts = [
(k, len(getattr(setup[0], k)), getattr(scheduling, k)) for k in WORKER_TYPES
(k, len(getattr(setup[0], k)), getattr(scheduling, k))
for k in WORKER_TYPES
if len(getattr(setup[0], k)) > 0
]
env_vars = constants.get_env_vars(

View File

@ -32,6 +32,7 @@ class FunctionExecutor:
model_configs: Dict[str, None | ReaLModelConfig],
ctrl: RPCCorountineControl,
summary_writer: SummaryWriter | None,
shuffle_dataset: bool,
):
self.func_calls: Dict[str, ModelFunctionCall] = {}
@ -67,6 +68,7 @@ class FunctionExecutor:
self.buffer = buffer
self.data_loading_dp_idx = -1
self.shuffle_dataset = shuffle_dataset
# Sort all MFCs in the topological order and
# calculate the width of each level.
@ -110,84 +112,71 @@ class FunctionExecutor:
self.ctrl.ids_to_clear.clear()
async def load_data(self):
src_rpc = self.src_rpc
src_rpc_model_name = src_rpc.model_name
buffer = self.buffer
ctrl = self.ctrl
dp_idx = self.data_loading_dp_idx
received_ids = set()
while self.buffer.size < max(rpc.n_seqs for rpc in self.rpcs):
dp_idx += 1
dp_idx %= self.src_dp_size
resps = await self.stream.call_async(
handlers=[f"__data{dp_idx}__"],
handlers=[f"__data{dp_idx}__" for dp_idx in range(self.src_dp_size)],
handle_type="fetch",
datas=[None],
datas=[None for _ in range(self.src_dp_size)],
verbose=False,
)
x: DataBatchMeta | None = resps[0]
if x is None:
continue
if x.meta_sample is None:
continue
all_data = []
data_cnt = []
gpu_id_data = {}
for dp_rank, x in enumerate(resps):
x: DataBatchMeta | None
all_data = x.meta_sample.unpack()
if x is None:
data_cnt.append(0)
continue
if x.meta_sample is None:
data_cnt.append(0)
continue
filtered_data = []
ids_to_ignore = []
for xx in x.meta_sample.unpack():
async with ctrl.lock:
if xx.ids[0] in ctrl.hash_vals_to_ignore_in_recover:
ctrl.hash_vals_to_ignore_in_recover.remove(xx.ids[0])
ids_to_ignore.append(xx.ids[0])
else:
for xx in x.meta_sample.unpack():
async with ctrl.lock:
if xx.ids[0] in received_ids:
raise ValueError(f"Duplicate data id {xx.ids[0]}.")
received_ids.add(xx.ids[0])
filtered_data.append(xx)
if ids_to_ignore:
# Clear ignored data.
self.stream.request(
handlers=list(range(self.n_model_workers)),
handle_type="clear_data_cache",
datas=[ids_to_ignore for _ in list(range(self.n_model_workers))],
no_syn=True,
)
gpu_id = self.stream.route_to(f"__data{dp_rank}__")
all_data += x.meta_sample.unpack()
gpu_id_data[gpu_id] = x.meta_sample.unpack()
data_cnt.append(x.meta_sample.bs)
all_data = filtered_data
# We load data in a round-robin manner across different DP ranks,
# so we also need to shuffle the data to fuse different dataset splits.
random.shuffle(all_data)
if self.shuffle_dataset:
# We load data in a round-robin manner across different DP ranks,
# so we also need to shuffle the data to fuse different dataset splits.
random.shuffle(all_data)
if len(all_data) > 0:
# Update resource tracker for planning data redistribution.
gpu_id = self.stream.route_to(f"__data{dp_idx}__")
for k in all_data[0].keys:
await self.storage_tracker.add_data(
gpu_id,
[x.ids[0] for x in all_data],
k,
is_owner=True,
)
for gpu_id, data in gpu_id_data.items():
for k in data[0].keys:
await self.storage_tracker.add_data(
gpu_id,
[d.ids[0] for d in data],
k,
is_owner=True,
)
# Store into buffer!
buffer_indices = await buffer.put_batch(all_data)
assert len(buffer_indices) == len(all_data)
blogger.info(
f"Master worker loaded {len(all_data)} pieces of data from DP rank {dp_idx}. "
f"Remaining number of data to ignore: {len(self.ctrl.hash_vals_to_ignore_in_recover)}. "
f"Current buffer size: {buffer.size}/{buffer.max_size}. "
)
self.data_loading_dp_idx = dp_idx
blogger.info(
f"Master worker loaded {len(all_data)} pieces of data from all dp ranks: "
f"{data_cnt} from each rank. "
f"Current buffer size: {buffer.size}/{buffer.max_size}. "
)
else:
await asyncio.sleep(1)
def execute_step(self):
logger.info("Waiting for the finish of the execution graph.")

View File

@ -0,0 +1,85 @@
import os
import time
from realhf.api.cli_args import SGLangConfig
from realhf.api.core.system_api import GenerationServer as GenerationServerConfig
from realhf.base import gpu_utils, logging, name_resolve, names, network, seeding
from realhf.system.worker_base import PollResult, Worker
logger = logging.getLogger(__name__)
class GenerationServer(Worker):
def _configure(self, config: GenerationServerConfig):
self.config = config
self.worker_index = config.worker_info.worker_index
self.worker_count = config.worker_info.worker_count
self.experiment_name = config.worker_info.experiment_name
self.trial_name = config.worker_info.trial_name
seeding.set_random_seed(
config.base_seed, f"generation_server{self.worker_index}"
)
# Cancel the effect of CUDA device isolation
if "CUDA_VISIBLE_DEVICES" in os.environ:
self.base_gpu_id = int(os.environ["CUDA_VISIBLE_DEVICES"])
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
map(str, range(gpu_utils.gpu_count()))
)
self.server_process = None
self.server_addr = None
return config.worker_info
def launch_server_subprocess(self):
config = self.config
assert config.backend_type == "sglang"
cmd = SGLangConfig.build_cmd(
config.backend_args,
config.model_path,
tp_size=config.tp_size,
server_index=self.worker_index,
base_gpu_id=self.base_gpu_id,
)
from sglang.utils import launch_server_cmd, wait_for_server
host_ip = network.gethostip()
host = "localhost" if not config.backend_args.enable_metrics else host_ip
# TODO: handle launching error and retry
self.server_process, self.server_port = launch_server_cmd(cmd)
self.server_addr = f"http://{host}:{self.server_port}"
wait_for_server(self.server_addr)
name = names.gen_servers(self.experiment_name, self.trial_name)
name_resolve.add_subentry(name, self.server_addr)
key = names.metric_server(
self.experiment_name,
self.trial_name,
"sglang",
f"server{self.worker_index}",
)
name_resolve.add(
key, f"{host}:{self.server_port}", keepalive_ttl=None, delete_on_exit=True
)
logger.info(f"SGLang server launched at: {self.server_addr}")
def _poll(self):
if self.server_process is None:
self.launch_server_subprocess()
# TODO: we may want to collect some metrics from the server
time.sleep(0.05)
return PollResult(0, 0)
def _exit_hook(self, exit_status):
if self.server_process is not None and self.config.backend_type == "sglang":
from sglang.utils import terminate_process
terminate_process(self.server_process)

View File

@ -0,0 +1,354 @@
# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0 (the "License").
import asyncio
import os
import shutil
import threading
import time
from collections import defaultdict
from typing import List
import aiohttp
from realhf.api.core.model_api import GenReqMeta, GenRespMeta, ModelVersionReq
from realhf.api.core.system_api import GserverManager as GserverManagerConfig
from realhf.base import constants, logging, name_resolve, names, network, recover
from realhf.system.worker_base import AsyncWorker, PollResult, Worker
logger = logging.getLogger("Generation Manager", "colored")
STALENESS_WARNED = defaultdict(lambda: False)
class GserverManager(Worker):
"""This worker has the following functionalities:
1. As a router, it schedules generation requests and returns the
best server urls to clients for submitting generation requests.
2. It manages the weight update requests of generation servers.
The weight update manager must be unique in each experiment.
This is currently a hacky use of SGLang; the functionality can be
integrated into sgl-router and srt in the future.
"""
def _configure(self, config: GserverManagerConfig):
self.config = config
self.model_name = config.model_name
assert self.config.worker_info.worker_count == 1
self.async_lock = asyncio.Lock()
self.threading_lock = threading.Lock()
self.n_total_rollouts = 0
self.n_running_rollouts = 0
self.accepted_rollouts = 0
self.schedule_policy = config.schedule_policy
self._last_param_realloc_step = 0
self.experiment_name = config.worker_info.experiment_name
self.trial_name = config.worker_info.trial_name
# manager server
self.server = None
self.thread = None
# recover info
self.__recover_run, self.__recover_info = recover.load_recover_info()
if self.__recover_run:
# Weight updates will be triggered automatically upon the first schedule_request;
# self._last_param_realloc_step will also be updated.
name = names.model_version(
constants.experiment_name(),
constants.trial_name(),
self.model_name.role,
)
name_resolve.add(name, self.__recover_info.last_step_info.global_step)
self._loaded_recover_weights = False
self.n_total_rollouts = self.accepted_rollouts = (
self.config.train_batch_size
* self.__recover_info.last_step_info.global_step
)
return config.worker_info
def _discover_servers(self, n_servers: int, timeout: int = 300) -> List[str]:
logger.info(f"Waiting for {n_servers} generation servers...")
name = names.gen_servers(self.experiment_name, self.trial_name)
cnt = 0
while len(name_resolve.find_subtree(name)) < n_servers:
time.sleep(1)
cnt += 1
if cnt >= timeout:
raise TimeoutError("Waiting for generation servers timed out.")
urls = name_resolve.get_subtree(name)
assert len(set(urls)) == len(urls), urls
return urls
def _get_recover_ckpt_path(self, role: str):
assert self.__recover_run
epoch = self.__recover_info.last_step_info.epoch + 1
epochstep = self.__recover_info.last_step_info.epoch_step + 1
globalstep = self.__recover_info.last_step_info.global_step + 1
save_root = os.path.join(
constants.MODEL_SAVE_ROOT,
constants.experiment_name(),
constants.trial_name(),
)
role_path = os.path.join(save_root, role)
if not os.path.exists(role_path):
raise RuntimeError(
f"Guessed checkpoint path {role_path} does not exist. "
"Skip loading checkpoints in the recovered run."
)
model_path = os.path.join(
role_path,
f"epoch{epoch}epochstep{epochstep}globalstep{globalstep}",
)
if not os.path.exists(model_path):
raise RuntimeError(
f"Guessed checkpoint path {model_path} does not exist. "
"Skip loading checkpoints in the recovered run."
)
return model_path
def check_new_params(self) -> str | None:
name = names.model_version(
constants.experiment_name(),
constants.trial_name(),
self.model_name.role,
)
try:
realloc_version = int(name_resolve.get(name))
except name_resolve.NameEntryNotFoundError:
return None
# Update the model weights after parameter reallocation.
if realloc_version > self._last_param_realloc_step:
if self.__recover_run and not self._loaded_recover_weights:
realloc_dir = self._get_recover_ckpt_path(self.model_name.role)
self._loaded_recover_weights = True
else:
realloc_dir = os.path.join(
constants.PARAM_REALLOC_PATH,
constants.experiment_name(),
constants.trial_name(),
self.model_name.role,
str(realloc_version),
)
self._last_param_realloc_step = realloc_version
return realloc_dir
return None
async def flush_requests_and_update_weights(
self, server_url, new_param_path, update_weights_retries=5
):
# HACK: urls are designed for SGLang
server_index = self.server_urls.index(server_url)
async with aiohttp.ClientSession(server_url) as session:
running_requests = None
tik = time.perf_counter()
while running_requests is None or running_requests > 0:
if time.perf_counter() - tik > self.config.flush_request_timeout:
raise RuntimeError(
f"Waiting for flush requests failed. {running_requests} requests "
f"remain after {self.config.flush_request_timeout} secs waiting. "
f"Please try to reduce `new_tokens_per_chunk`."
)
if running_requests is not None and running_requests > 0:
logger.info(
f"Waiting for {running_requests} requests on gen server {server_index}... "
f"Time taken so far: {time.perf_counter() - tik:.4f}s"
)
await asyncio.sleep(0.5)
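# Poll the server's Prometheus-style /metrics endpoint to read the number of
# in-flight requests before swapping weights.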
async with session.get(f"/metrics") as resp:
resp.raise_for_status()
text = await resp.text()
for line in text.split("\n"):
if line.startswith("sglang:num_running_reqs"):
running_requests = float(line.split(" ")[1])
break
success = False
for _ in range(update_weights_retries):
async with session.post(
f"/update_weights_from_disk",
json=dict(model_path=new_param_path),
) as resp:
if resp.status == 200:
res = await resp.json()
success = res["success"]
if success:
return
logger.warning(
f"Update weights failed: {res['message']}. Retrying."
)
logger.warning(f"Update weights failed: {resp.reason}. Retrying.")
time.sleep(0.1)
raise RuntimeError("Update weights failed.")
def _round_robin_schedule(self, req_meta: GenReqMeta) -> int:
if not hasattr(self, "round_robin_idx"):
self.round_robin_idx = 0
r = self.round_robin_idx
self.round_robin_idx += 1
self.round_robin_idx %= self.config.n_servers
return r
def _poll(self):
if not self.thread:
# Find addresses of generation servers
self.server_urls = self._discover_servers(self.config.n_servers)
self.thread = threading.Thread(
target=self._run_routing_service, daemon=True
)
self.thread.start()
time.sleep(3) # Wait briefly for server to start
# Write address for clients
name = names.gen_server_manager(self.experiment_name, self.trial_name)
name_resolve.add(name, self.manager_addr)
logger.info(
f"GserverManager HTTP service started in background thread at {self.manager_addr}"
)
# Check weights.
with self.threading_lock:
# FIXME: we create a sync point across servers to update weights,
# but we can actually update them individually
new_param_path = self.check_new_params()
if new_param_path is not None:
tasks = [
self.flush_requests_and_update_weights(base_url, new_param_path)
for base_url in self.server_urls
]
loop = asyncio.get_event_loop()
loop.run_until_complete(asyncio.gather(*tasks))
logger.info(f"Generaion server updated weights from: {new_param_path}")
# clear old weights
realloc_root = os.path.join(
constants.PARAM_REALLOC_PATH,
constants.experiment_name(),
constants.trial_name(),
self.model_name.role,
)
if os.path.exists(realloc_root):
for realloc_version in os.listdir(realloc_root):
if (
os.path.isdir(os.path.join(realloc_root, realloc_version))
and int(realloc_version) < self._last_param_realloc_step
):
shutil.rmtree(os.path.join(realloc_root, realloc_version))
logger.info(
f"Removed previous reallocated "
f"checkpoint: {os.path.join(realloc_root, realloc_version)}"
)
# TODO: we may want to update server status
# in the main thread.
time.sleep(1)
return PollResult(0, 0)
async def is_staled(self):
global_sample_cnt = self.n_total_rollouts
expected_version = global_sample_cnt // self.config.train_batch_size
staled = (
expected_version
> self.config.max_head_offpolicyness + self._last_param_realloc_step
)
global STALENESS_WARNED
if staled and not STALENESS_WARNED[self._last_param_realloc_step]:
logger.warning(
f"expected version ({expected_version}) = "
f"global sample cnt ({global_sample_cnt}) // batch size ({self.config.train_batch_size}), "
f"current version {self._last_param_realloc_step}, "
f"offpolicyness {self.config.max_head_offpolicyness}. Staled? {staled}"
)
STALENESS_WARNED[self._last_param_realloc_step] = True
return staled
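For intuition, the staleness rule above with made-up numbers:

# Worked example (hypothetical values):
train_batch_size = 256
max_head_offpolicyness = 1
n_total_rollouts = 768
last_param_realloc_step = 1

expected_version = n_total_rollouts // train_batch_size                       # 3
staled = expected_version > max_head_offpolicyness + last_param_realloc_step  # True
# New rollouts are rejected until the trained weights reach at least version 2.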
def _run_routing_service(self):
"""Expose an API for clients to find the destination server."""
import uvicorn
from fastapi import FastAPI
self.app = FastAPI()
@self.app.post("/schedule_request")
async def schedule_request(req_meta: GenReqMeta):
with self.threading_lock:
async with self.async_lock:
version = self._last_param_realloc_step
# FIXME: We only implement a round-robin scheduler that
# ignores server status and request metadata
server_idx = self._round_robin_schedule(req_meta)
return dict(url=self.server_urls[server_idx], version=max(0, version))
@self.app.post("/get_model_version")
async def get_model_version(req: ModelVersionReq):
with self.threading_lock:
async with self.async_lock:
# FIXME: we may have different versions for different servers
version = self._last_param_realloc_step
return dict(version=version)
@self.app.get("/allocate_rollout")
async def allocate_rollout():
with self.threading_lock:
async with self.async_lock:
has_capacity = (
self.n_running_rollouts < self.config.max_concurrent_rollouts
)
is_staled = await self.is_staled()
reason = ""
if has_capacity and not is_staled:
self.n_running_rollouts += 1
self.n_total_rollouts += 1
return dict(success=True, reason=reason)
else:
if not has_capacity:
reason += f"capacity: {self.n_running_rollouts} >= {self.config.max_concurrent_rollouts}"
if is_staled:
global_sample_cnt = self.n_total_rollouts
expected_version = (
global_sample_cnt // self.config.train_batch_size
)
reason += (
f" and staled: expected version ({expected_version}) = "
f"global sample cnt ({global_sample_cnt}) // batch size ({self.config.train_batch_size}), "
f"current version {self._last_param_realloc_step}, "
f"offpolicyness {self.config.max_head_offpolicyness}."
)
return dict(success=False, reason=reason)
@self.app.post("/finish_rollout")
async def finish_rollout(resp_meta: GenRespMeta):
with self.threading_lock:
async with self.async_lock:
self.n_running_rollouts -= 1
if resp_meta.accepted:
self.accepted_rollouts += 1
return dict(success=True)
self.manager_addr = f"{network.gethostip()}:{network.find_free_port()}"
config = uvicorn.Config(
self.app,
host=self.manager_addr.split(":")[0],
port=int(self.manager_addr.split(":")[1]),
log_level="warning",
)
self.server = uvicorn.Server(config)
self.server.run()
def _exit_hook(self, exit_status):
if self.server:
self.server.should_exit = True
if self.thread:
self.thread.join(timeout=3)
logger.info("Server stopped")

View File

@ -15,13 +15,6 @@ import numpy as np
import wandb
from tensorboardX import SummaryWriter
try:
import uvloop
uvloop.install()
except (ModuleNotFoundError, ImportError):
pass
import realhf.api.core.dfg as dfg
import realhf.api.core.model_api as model_api
import realhf.api.core.system_api as config_pkg
@ -53,7 +46,7 @@ class MasterWorker(worker_base.Worker):
def _configure(self, config: config_pkg.MasterWorker):
self.config = config
seeding.set_random_seed(self.config.base_seed + self.config.n_model_workers)
seeding.set_random_seed(self.config.base_seed, "master_worker")
self.__model_topos: Dict[ModelName, topology.ProcessTopology] = (
config.model_topos
@ -133,11 +126,6 @@ class MasterWorker(worker_base.Worker):
if self.__recover_run
else list()
),
hash_vals_to_ignore_in_recover=(
copy.deepcopy(self.__recover_info.hash_vals_to_ignore)
if self.__recover_run
else list()
),
)
if self.__recover_run:
@ -222,22 +210,15 @@ class MasterWorker(worker_base.Worker):
src_rpc_dp_size = src_rpc_topo.get_dim("data")
# Request training specification from data workers.
all_data = sum(
self._dataset_size = sum(
self.__stream.call(
handlers=[f"__data{i}__" for i in range(src_rpc_dp_size)],
datas=[None for i in range(src_rpc_dp_size)],
handle_type="spec",
),
[],
)
# NOTE: For dynamic datasets, we still count epoch according to the initial number of data,
# such that the learning rate decay is not affected.
seqlens = [max(sum(v[0]) for v in x.seqlens.values()) for x in all_data]
self._dataset_size = len(all_data)
self._steps_per_epoch = self._dataset_size // src_rpc.n_seqs
self._avg_tokens_per_batch = sum(seqlens) / self._steps_per_epoch
self._dataset_ids = [copy.deepcopy(x.ids[0]) for x in all_data]
# Request model configs from model workers.
# Return None if the model is not a ReaLModel.
@ -324,6 +305,7 @@ class MasterWorker(worker_base.Worker):
model_configs=self.__model_configs,
ctrl=self.__rpc_ctrl,
summary_writer=self.__summary_writer,
shuffle_dataset=self.config.shuffle_dataset,
)
if self.__recover_run:
self.func_executor.data_loading_dp_idx = (
@ -455,9 +437,11 @@ class MasterWorker(worker_base.Worker):
s = f"Epoch {epoch}/{self.config.exp_ctrl.total_train_epochs} "
s += f"step {epoch_step}/{self._steps_per_epoch} "
s += f"(global step {global_step}) finishes. "
s += f"Average #tokens per batch is {self._avg_tokens_per_batch:.0f}. "
s += f"#End to end# execution time: *{e2e_time:.3f}*s. "
s += f"Total time consumption: {time_since_configure:.3f}s. "
logging.log_wandb_tensorboard(
{"timeperf/e2e": e2e_time}, step=self.__rpc_ctrl.step_info.global_step
)
if len(self.e2e_time_history) > 2:
remaining_steps = self._steps_per_epoch - epoch_step
remaining_epochs = self.__total_train_epochs - epoch

View File

@ -22,7 +22,7 @@ import realhf.system.request_reply_stream as request_reply_stream
from realhf import ModelShardID
from realhf.api.core.config import ModelName
from realhf.api.core.model_api import ReaLModelConfig
from realhf.base import constants, logging, topology
from realhf.base import constants, logging, stats_tracker, topology
from realhf.system.buffer import AsyncIOSequenceBuffer
from realhf.system.flops_counter import FlopsCounter
from realhf.system.redistributor import RedistribPlanner, RedistribStep
@ -51,7 +51,6 @@ class RPCCorountineControl:
# recover information
used_hash_vals_this_epoch: List[int] = dataclasses.field(default_factory=list)
hash_vals_to_ignore_in_recover: List[int] = dataclasses.field(default_factory=list)
class ModelFunctionCall:
@ -412,7 +411,7 @@ class ModelFunctionCall:
# Filter out responses other than DP heads.
# Other responses are duplicated or None.
responses = [responses[i] for i in dp_head_indices]
responses, time_records = list(zip(*[responses[i] for i in dp_head_indices]))
# If the returned data is a SequenceSample, it is the data returned by
# model function calls. The data should be amended into the buffer.
@ -439,19 +438,32 @@ class ModelFunctionCall:
res = data_api.gather_stat(responses)
if rpc.log_return_value:
logger.info(f"RPC name {rpc.name} returns {res}")
if isinstance(res, dict):
logger.info(
f"RPC name {rpc.name} returns\n{data_api.tabulate_stats(res)}"
)
logging.log_wandb_tensorboard(
res,
step=ctrl.step_info.global_step,
summary_writer=self.summary_writer,
)
else:
logger.info(f"RPC name {rpc.name} returns\n{res}")
if isinstance(res, Dict):
wandb.log(res, step=ctrl.step_info.global_step)
if self.summary_writer is not None:
for key, val in res.items():
self.summary_writer.add_scalar(
f"{key}", val, ctrl.step_info.global_step
)
# Log RPC execution time.
for time_record in time_records:
stats_tracker.scalar(**time_record)
time_stats = stats_tracker.export()
logging.log_wandb_tensorboard(
time_stats,
step=ctrl.step_info.global_step,
summary_writer=self.summary_writer,
)
logger.info(
f"Model rpc {rpc.name} finished. "
f"Run time {time.perf_counter() - tik:.4f}s."
f"Request-reply time {time.perf_counter() - tik:.4f}s. "
f"Detailed time stats:\n{data_api.tabulate_stats(time_stats, floatfmt='.2f')}."
)
# If this RPC is the final node in the dataflow graph,

View File

@ -2,7 +2,6 @@
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import collections
import contextlib
import copy
import gc
@ -17,6 +16,7 @@ import shutil
import socket
import time
import uuid
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, Hashable, List, Optional, Set, Tuple
@ -54,6 +54,7 @@ from realhf.impl.model.utils import cuda_graph
from realhf.system import request_reply_stream, worker_base
from realhf.system.data_manager import DataManager
from realhf.system.redistributor import RedistribStep
from realhf.system.stream_dataset import PullerStreamDataset
# NOTE: Register all implemented datasets and models.
import realhf.impl.dataset # isort:skip
@ -119,7 +120,7 @@ class ModelWorker(worker_base.Worker):
self.__worker_index = cfg.worker_info.worker_index
seeding.set_random_seed(cfg.base_seed + self.__worker_index)
seeding.set_random_seed(cfg.base_seed, f"model_worker{self.__worker_index}")
# Reveal process group identity of this worker to world.
gpu_utils.reveal_pg_identity(
@ -143,6 +144,28 @@ class ModelWorker(worker_base.Worker):
self.__enable_memory_dump = os.getenv("REAL_DUMP_MEMORY", "0") == "1"
self.__performance_recorder = dict()
# Add an additional subscription pattern for source RPCs.
self.__has_dataset = False
self.__dataset_dp_size = self.__dataset_dp_rank = 0
sub_patterns = [s.id for s in self.config.shards]
self.src_rpc = src_rpc = [rpc for rpc in self.config.model_rpcs if rpc.is_src][
0
]
for s in self.config.shards:
_pp_size = s.id.topo.get_dim("pipe")
if not (s.id.mp_rank == 0 and s.id.pp_rank == _pp_size - 1):
continue
if src_rpc.model_name == s.id.model_name:
self.__has_dataset = True
self.__dataset_dp_size = s.id.topo.get_dim("data")
self.__dataset_dp_rank = s.id.dp_rank
sub_patterns.append(f"__data{self.__dataset_dp_rank}__")
break
if self.__has_dataset:
name = names.stream_pullers(self.__experiment_name, self.__trial_name)
name_resolve.add_subentry(name, str(self.__dataset_dp_rank))
return r
def _get_recover_ckpt_path(self, role: str):
@ -220,22 +243,6 @@ class ModelWorker(worker_base.Worker):
return self.__backends[constants.model_name()]
def __lazy_setup(self):
# Add an additional subscript pattern for source RPCs.
self.__has_dataset = False
self.__dataset_dp_size = self.__dataset_dp_rank = 0
sub_patterns = [s.id for s in self.config.shards]
src_rpc = [rpc for rpc in self.config.model_rpcs if rpc.is_src][0]
self.__src_rpc_model_name = src_rpc.model_name
for s in self.config.shards:
_pp_size = s.id.topo.get_dim("pipe")
if not (s.id.mp_rank == 0 and s.id.pp_rank == _pp_size - 1):
continue
if src_rpc.model_name == s.id.model_name:
self.__has_dataset = True
self.__dataset_dp_size = s.id.topo.get_dim("data")
self.__dataset_dp_rank = s.id.dp_rank
sub_patterns.append(f"__data{self.__dataset_dp_rank}__")
break
# Build stream connecting with master workers.
self.__stream = request_reply_stream.make_worker_stream(
@ -321,19 +328,22 @@ class ModelWorker(worker_base.Worker):
g = torch.Generator()
g.manual_seed(seeding.get_seed())
self.__dataloader = torch.utils.data.DataLoader(
self.__dataset,
collate_fn=data_api.SequenceSample.gather,
# NOTE: This is *NOT* the actual batch size for training.
# It is just a proper size to load data to workers.
batch_size=10240,
shuffle=True,
dataloader_kwargs = dict(
shuffle=self.config.shuffle_dataset,
generator=g,
)
if not isinstance(self.__dataset, PullerStreamDataset):
dataloader_kwargs["collate_fn"] = data_api.SequenceSample.gather
# NOTE: This is *NOT* the actual batch size for training.
# It is just a proper size to load data to workers.
dataloader_kwargs["batch_size"] = 10240
else:
dataloader_kwargs["batch_size"] = None
self.__dataloader = torch.utils.data.DataLoader(
self.__dataset, **dataloader_kwargs
)
self.__raw_samples = []
for tmp_sample in self.__dataloader:
self.__raw_samples += tmp_sample.meta().unpack()
self.dataset_size = len(self.__dataset)
self.__data_generator = enumerate(self.__dataloader)
@ -368,7 +378,7 @@ class ModelWorker(worker_base.Worker):
# Recover indices for dynamic dataset
if (
s.id.model_name == src_rpc.model_name
s.id.model_name == self.src_rpc.model_name
and self.__has_dataset
and hasattr(self.__dataset, "filter")
):
@ -471,13 +481,6 @@ class ModelWorker(worker_base.Worker):
for model_name in self.__models.keys()
}
# By intention, must be smaller than -1.
self._last_param_realloc_step = -100
if self.__recover_run:
self._last_param_realloc_step = (
self.__recover_info.last_step_info.global_step
)
def __handle_one_rpc_hook(self, hook: str, hook_data: Any):
ret = None
@ -534,7 +537,9 @@ class ModelWorker(worker_base.Worker):
cache = []
while True:
try:
request, data, handled, res = self.__request_queue.get_nowait()
request, data, handled, res, time_record = (
self.__request_queue.get_nowait()
)
request: request_reply_stream.Payload
if not handled:
while len(request.pre_hooks) > 0:
@ -548,11 +553,14 @@ class ModelWorker(worker_base.Worker):
f"The current hook is `{request.pre_hooks[0]}`. "
f"{self.__request_queue.qsize()} requests left to handle their potential pre-hooks."
)
self.__handle_one_rpc_hook(
request.pre_hooks.pop(0),
request.pre_hook_data.pop(0),
)
cache.append((request, data, handled, res))
tik = time.perf_counter()
hook = request.pre_hooks.pop(0)
hook_data = request.pre_hook_data.pop(0)
self.__handle_one_rpc_hook(hook, hook_data)
time_record[
f"timeperf/{request.handler.model_name.role}_{request.handle_name}/pre-{hook}"
] += (time.perf_counter() - tik)
cache.append((request, data, handled, res, time_record))
except queue.Empty:
break
@ -607,21 +615,37 @@ class ModelWorker(worker_base.Worker):
)
g = torch.Generator()
g = g.set_state(self.__dataloader.generator.get_state())
self.__dataloader = torch.utils.data.DataLoader(
self.__dataset,
collate_fn=data_api.SequenceSample.gather,
dataloader_kwargs = dict(
shuffle=self.config.shuffle_dataset,
generator=g,
)
if not isinstance(self.__dataset, PullerStreamDataset):
dataloader_kwargs["collate_fn"] = data_api.SequenceSample.gather
# NOTE: This is *NOT* the actual batch size for training.
# It is just a proper size to load data to workers.
batch_size=10240,
shuffle=True,
generator=g,
dataloader_kwargs["batch_size"] = 10240
else:
dataloader_kwargs["batch_size"] = None
self.__dataloader = torch.utils.data.DataLoader(
self.__dataset, **dataloader_kwargs
)
self.__data_generator = enumerate(self.__dataloader)
self.__dataset_batch_counter, cur_sample = next(self.__data_generator)
# Defer data that has not been used in the previous epoch.
if isinstance(cur_sample, data_api.SequenceSample):
samples = cur_sample.unpack()
else:
assert isinstance(cur_sample, list), type(cur_sample)
samples = cur_sample
data_loaded = []
for x in cur_sample.unpack():
for x in samples:
if (
self.__recover_run
and x.ids[0] in self.__recover_info.hash_vals_to_ignore
):
self.__recover_info.hash_vals_to_ignore.remove(x.ids[0])
continue
if self.data_manager.has_data(x.ids[0]):
continue
data_loaded.append(x)
@ -639,7 +663,7 @@ class ModelWorker(worker_base.Worker):
)
elif request.handle_name == "spec":
# Raw dataset without filtering.
res = self.__raw_samples
res = self.dataset_size
elif request.handle_name == "clear_data_cache":
with cuda_tmarked("clear_data_cache", CUDATimeMarkType.misc):
ids = request.data
@ -670,6 +694,7 @@ class ModelWorker(worker_base.Worker):
data: Any,
handled: bool,
res: Optional[Any],
time_record: Dict,
) -> worker_base.PollResult:
tik = time.perf_counter()
@ -734,25 +759,35 @@ class ModelWorker(worker_base.Worker):
f"request *{request.handle_name}*"
f" in ${time.perf_counter() - tik:.4f}$s"
)
time_record[
f"timeperf/{request.handler.model_name.role}_{request.handle_name}/main"
] += (time.perf_counter() - tik)
# Handle all post hooks right after the main computation
if len(request.post_hooks) > 0:
assert len(request.post_hooks) == len(request.post_hook_data)
for hook, hook_data in zip(request.post_hooks, request.post_hook_data):
tik = time.perf_counter()
ret = self.__handle_one_rpc_hook(hook, hook_data)
if hook == "evaluate":
assert request.handle_name == "train_step", request.handle_name
assert isinstance(ret, dict), ret
assert isinstance(res, dict), res
res.update({f"eval_{k}": v for k, v in ret.items()})
res.update(ret)
time_record[
f"timeperf/{request.handler.model_name.role}_{request.handle_name}/post-{hook}"
] += (time.perf_counter() - tik)
# update param realloc step after handling post hooks
if request.handle_name == "train_step":
self._last_param_realloc_step = max(self._last_param_realloc_step + 1, 1)
tik = time.perf_counter()
global_step = self.__models[model_name].version.global_step
realloc_dir = os.path.join(
constants.PARAM_REALLOC_PATH,
constants.experiment_name(),
constants.trial_name(),
model_name.role,
str(global_step),
)
save_meta = dict(
model_name=model_name,
@ -766,10 +801,20 @@ class ModelWorker(worker_base.Worker):
model_name.role,
)
with constants.model_scope(model_name):
dist.barrier(group=constants.parallelism_group())
dist.barrier(group=constants.cpu_parallelism_group())
if constants.parallelism_rank() == 0:
name_resolve.add_subentry(name, str(self._last_param_realloc_step))
name_resolve.add(
name,
str(global_step),
delete_on_exit=False,
keepalive_ttl=30,
replace=True,
)
time_record[
f"timeperf/{request.handler.model_name.role}_{request.handle_name}/param-sync-save"
] += (time.perf_counter() - tik)
res = (res, time_record)
self.__reply_queue.put_nowait((request, res))
sample_count = data.bs if isinstance(data, data_api.SequenceSample) else 1
self.__request_sample_size[request.request_id] = sample_count
@ -1022,7 +1067,7 @@ class ModelWorker(worker_base.Worker):
# If the source is not a trainable model, it will not own
# parameters, so we just release its GPU memory.
with constants.model_scope(from_model_name):
from_model_ranks = constants.parallelism_group_ranks()
from_model_ranks = sorted(constants.parallelism_group_ranks())
if not param_realloc_comm.is_trainable(from_model_name):
if dist.get_rank() not in from_model_ranks:
return
@ -1037,11 +1082,28 @@ class ModelWorker(worker_base.Worker):
m.contiguous_param = dummy_tensor
return
# Get global_step from source model via broadcast,
# since there is no global_step information on the model workers used for generation.
if (
from_model_name in self.__models
and dist.get_rank() == from_model_ranks[0]
):
global_step = self.__models[from_model_name].version.global_step
else:
global_step = 0
g = self.__param_realloc_info.param_realloc_model_cpu_group[
param_realloc_comm.ParamReallocModelPair(from_model_name, to_model_name)
]
global_step = torch.tensor(global_step, device="cpu")
dist.broadcast(global_step, src=from_model_ranks[0], group=g)
global_step = int(global_step.item())
realloc_dir = os.path.join(
constants.PARAM_REALLOC_PATH,
constants.experiment_name(),
constants.trial_name(),
from_model_name.role,
str(global_step),
)
if from_model_name in self.__unwrapped_models:
save_meta = dict(
@ -1050,9 +1112,6 @@ class ModelWorker(worker_base.Worker):
save_dir=realloc_dir,
)
self.__save_model(save_meta)
g = self.__param_realloc_info.param_realloc_model_cpu_group[
param_realloc_comm.ParamReallocModelPair(from_model_name, to_model_name)
]
dist.barrier(group=g)
if to_model_name in self.__unwrapped_models:
load_meta = dict(
@ -1229,7 +1288,9 @@ class ModelWorker(worker_base.Worker):
self.__ack_cache[r.request_id] = r
else:
if r.no_syn:
self.__request_queue.put_nowait((r, r.data, False, None))
self.__request_queue.put_nowait(
(r, r.data, False, None, defaultdict(int))
)
else:
self.__stream.post(
request_reply_stream.Payload(
@ -1253,7 +1314,9 @@ class ModelWorker(worker_base.Worker):
if ack_id in self.__request_cache:
self.__ack_cache.pop(ack_id)
req = self.__request_cache.pop(ack_id)
self.__request_queue.put_nowait((req, req.data, False, None))
self.__request_queue.put_nowait(
(req, req.data, False, None, defaultdict(int))
)
def _poll(self):
if not self.__dist_env_resolved:
@ -1279,7 +1342,7 @@ class ModelWorker(worker_base.Worker):
# are executed in the same order across all model workers.
flush = False
for _ in range(self.__request_queue.qsize()):
request, data, handled, res = self.__request_queue.get_nowait()
request, data, handled, res, time_record = self.__request_queue.get_nowait()
if request.handle_name == "reset":
# Pause the worker and wait for the next `configure`
# command from the controller.
@ -1289,7 +1352,9 @@ class ModelWorker(worker_base.Worker):
elif request.handle_name in NON_BLOCKING_RPCS:
self.handle_non_blocking_request(request)
else:
self.__request_queue.put_nowait((request, data, handled, res))
self.__request_queue.put_nowait(
(request, data, handled, res, time_record)
)
# Non-blocking requests are usually fast, so we can
# respond them in a batch without affecting the accuracy
@ -1309,25 +1374,37 @@ class ModelWorker(worker_base.Worker):
rescheduled_requests = []
other_requests = []
for _ in range(self.__request_queue.qsize()):
request, data, handled, res = self.__request_queue.get_nowait()
request, data, handled, res, time_record = (
self.__request_queue.get_nowait()
)
if request.handle_name not in ["inference", "generate", "train_step"]:
other_requests.append((request, data, handled, res))
other_requests.append((request, data, handled, res, time_record))
else:
with constants.model_scope(request.handler.model_name):
w = dist.get_world_size(constants.parallelism_group())
rescheduled_requests.append((request, data, handled, res, w))
rescheduled_requests.append(
(request, data, handled, res, time_record, w)
)
rescheduled_requests.sort(key=lambda x: x[-1])
for request, data, handled, res, _ in rescheduled_requests:
self.__request_queue.put_nowait((request, data, handled, res))
for request, data, handled, res in other_requests:
self.__request_queue.put_nowait((request, data, handled, res))
for request, data, handled, res, time_record, _ in rescheduled_requests:
self.__request_queue.put_nowait(
(request, data, handled, res, time_record)
)
for request, data, handled, res, time_record in other_requests:
self.__request_queue.put_nowait(
(request, data, handled, res, time_record)
)
# Execute one MFC, then immediately return the result, such that
# we can correctly log the time consumption in the master worker.
while True:
try:
request, data, handled, res = self.__request_queue.get_nowait()
self.handle_blocking_request(request, data, handled, res)
request, data, handled, res, time_record = (
self.__request_queue.get_nowait()
)
self.handle_blocking_request(
request, data, handled, res, time_record
)
r += self.maybe_post_responses()
except queue.Empty:
break

View File

@ -0,0 +1,274 @@
# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0 (the "License").
import asyncio
import time
from asyncio.queues import QueueEmpty
from collections import defaultdict
from dataclasses import asdict
from typing import Dict, Hashable, List
import aiohttp
from aiohttp.client import ClientTimeout
from transformers import PreTrainedTokenizerFast
from realhf.api.cli_args import GenerationHyperparameters
from realhf.api.core.model_api import (
APIGenerateInput,
APIGenerateOutput,
BundledGenerationOutputs,
GenReqMeta,
)
from realhf.base import constants, logging, name_resolve, names
logger = logging.getLogger(__name__)
GENERATION_POLL_WAIT_TIME = 0.05
class PartialRolloutManager:
"""Manages the partial rollout for a client.
It will submit generation requests in chunks, i.e.,
generating at most `new_tokens_per_chunk` tokens each time.
In this way, we can reduce the overhead of flushing all requests
upon model weights update.
This is a hack. We would not need it if the server could pause
requests, update weights, and recompute KV caches at any time.
"""
def __init__(
self,
worker_index: int,
request_queue: asyncio.Queue,
reply_queue: asyncio.Queue,
new_tokens_per_chunk: int,
tokenizer: PreTrainedTokenizerFast,
timeout: int,
):
self.worker_index = worker_index
# qid -> {group_idx -> aiohttp Task}
self.gen_requests: Dict[Hashable, Dict[int, asyncio.Task]]
self.gen_requests = defaultdict(dict)
# NOTE: Grouped generations are managed separately. Store early returned
# answers in this cache and pop the result when the whole group is done.
self.gen_cache: Dict[Hashable, Dict[int, APIGenerateOutput]]
self.gen_cache = defaultdict(dict)
self.tokenizer = tokenizer
self.request_queue = request_queue
self.reply_queue = reply_queue
self.new_tokens_per_chunk = new_tokens_per_chunk
self.gserver_manager_addr = None
self.timeout = timeout
async def _schedule_request(self, req_meta: GenReqMeta):
if self.gserver_manager_addr is None:
# Get the address of gserver manager to schedule requests
name = names.gen_server_manager(
constants.experiment_name(), constants.trial_name()
)
self.gserver_manager_addr = name_resolve.wait(name, timeout=300)
time.sleep(1) # Wait for the server to start
async with aiohttp.ClientSession() as session:
async with session.post(
f"http://{self.gserver_manager_addr}/schedule_request",
json=asdict(req_meta),
timeout=ClientTimeout(total=self.timeout, sock_connect=30),
) as response:
response.raise_for_status()
res = await response.json()
return res
def get_num_gen_requests(self):
return len(self.gen_requests)
async def _run_gen(
self,
url,
qid,
group_idx,
prompt_ids,
input_ids,
prev_logprobs,
version_start,
cur_server_version,
raw_gconfig,
):
from realhf.impl.model.backend.sglang import SGLangAPIClient
gconfig = raw_gconfig.new(
n=1,
max_new_tokens=min(raw_gconfig.max_new_tokens, self.new_tokens_per_chunk),
)
assert self.tokenizer.pad_token_id is not None
assert self.tokenizer.eos_token_id is not None
# No need to request a weight update here.
async with SGLangAPIClient(
generate_url=f"{url}/generate", update_weights_url=""
) as api_client:
res = await api_client.async_add_generate_request(
APIGenerateInput(
qid=qid,
prompt_ids=prompt_ids,
input_ids=input_ids,
gconfig=gconfig,
stop_token_ids=[
self.tokenizer.pad_token_id,
self.tokenizer.eos_token_id,
],
return_logprob=True,
version_start=version_start,
prev_logprobs=prev_logprobs,
metadata=dict(
group_idx=group_idx,
raw_gconfig=raw_gconfig,
server_url=url,
),
),
stream=False,
)
res.version_end = [cur_server_version for _ in range(res.group_size)]
return res
async def _issue_generation(
self,
url: str,
qid: Hashable,
group_idx: int,
prompt_ids: List[int],
input_ids: List[int],
prev_logprobs: List[float],
version_start: int,
raw_gconfig: GenerationHyperparameters,
cur_server_version: int,
):
"""Issue a generation request.
`input_ids` may contain previously generated tokens and can therefore be longer than `prompt_ids`.
If the model weights have been updated, the KV cache will be refreshed;
otherwise the server reuses the radix cache with no additional overhead.
"""
task = asyncio.create_task(
self._run_gen(
url,
qid,
group_idx,
prompt_ids,
input_ids,
prev_logprobs,
version_start=version_start,
cur_server_version=cur_server_version,
raw_gconfig=raw_gconfig,
)
)
self.gen_requests[qid][group_idx] = task
await asyncio.sleep(0)
async def refresh_generation(self):
tasks = []
for group_requests in self.gen_requests.values():
tasks += list(group_requests.values())
done = []
if tasks:
# No new checkpoint available, try to wait for the next complete sequence
done, _ = await asyncio.wait(
tasks,
timeout=GENERATION_POLL_WAIT_TIME,
return_when=asyncio.FIRST_COMPLETED,
)
for task in done:
s: APIGenerateOutput = await task
group_idx = s.metadata["group_idx"]
raw_gconfig = s.metadata["raw_gconfig"]
assert s.group_size == 1
no_eos = s.no_eos[0]
gen_len = s.gen_lens[0]
self.gen_requests[s.qid].pop(group_idx)
if len(self.gen_requests[s.qid]) == 0:
self.gen_requests.pop(s.qid)
if no_eos and gen_len < raw_gconfig.max_new_tokens:
# Unfinished request due to chunked generation.
# Send it back to continue.
async with aiohttp.ClientSession() as session:
async with session.post(
f"http://{self.gserver_manager_addr}/get_model_version",
json=dict(server_url=s.metadata["server_url"]),
timeout=ClientTimeout(total=self.timeout, sock_connect=30),
) as resp:
resp.raise_for_status()
cur_version = (await resp.json())["version"]
if len(s.output_logprobs) > 0:
prev_logprobs = s.prev_logprobs + s.output_logprobs[0]
else:
prev_logprobs = []
await self._issue_generation(
s.metadata["server_url"],
s.qid,
group_idx,
s.prompt_ids,
s.input_ids + s.output_ids[0],
version_start=s.version_start,
prev_logprobs=prev_logprobs,
raw_gconfig=raw_gconfig,
cur_server_version=cur_version,
)
else:
# Generation finishes. Save to cache for later fetching.
self.gen_cache[s.qid][group_idx] = s
if len(self.gen_cache[s.qid]) >= raw_gconfig.n:
gen_results = self.gen_cache.pop(s.qid)
output = BundledGenerationOutputs.from_api_outputs(
list(gen_results.values())
)
self.reply_queue.put_nowait(output)
async def poll_fresh_requests_task(self):
for _ in range(8):
try:
qid, prompt_token_ids, gconfig = self.request_queue.get_nowait()
req_meta = GenReqMeta(
prompt_len=len(prompt_token_ids),
group_size=gconfig.n,
new_token_budget=self.new_tokens_per_chunk,
predicted_new_tokens=None,
)
dst_server_info = await self._schedule_request(req_meta)
for group_idx in range(gconfig.n):
await self._issue_generation(
dst_server_info["url"],
qid,
group_idx,
prompt_token_ids,
prompt_token_ids,
version_start=dst_server_info["version"],
prev_logprobs=[],
raw_gconfig=gconfig,
cur_server_version=dst_server_info["version"],
)
except QueueEmpty:
break
async def poll_old_requests_task(self):
for _ in range(8):
await self.refresh_generation()
async def run_step(self):
await asyncio.gather(
self.poll_fresh_requests_task(),
self.poll_old_requests_task(),
)
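A minimal driver sketch for PartialRolloutManager, for reference only (the qid, sleep interval, and calling pattern are assumptions for illustration, not part of this PR). It shows the queue protocol used above: requests enter request_queue as (qid, prompt_token_ids, gconfig) tuples, and a BundledGenerationOutputs appears on reply_queue once all gconfig.n members of a group have finished.

import asyncio

async def drive_one_request(manager, prompt_token_ids, gconfig):
    # Payload format consumed by poll_fresh_requests_task: (qid, prompt_ids, gconfig).
    await manager.request_queue.put(("example-qid", prompt_token_ids, gconfig))
    while True:
        # Each step schedules fresh requests and refreshes chunked generations.
        await manager.run_step()
        try:
            # Populated by refresh_generation once the whole group is done.
            return manager.reply_queue.get_nowait()
        except asyncio.QueueEmpty:
            await asyncio.sleep(0.05)

In this PR the same loop is driven by RolloutWorker._poll_async (below), which interleaves run_step with trajectory collection.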

View File

@ -6,7 +6,7 @@ import orjson
import zmq
from zmq.utils.strtypes import asbytes
from realhf.base import logging
from realhf.base import logging, name_resolve, names, network
logger = logging.getLogger("ZMQ Push-Pull Stream")
@ -97,7 +97,7 @@ class ZMQJsonPuller:
self.poller = zmq.Poller()
self.poller.register(self.socket, zmq.POLLIN)
def pull(self, timeout_ms: Optional[int] = None) -> JSONType:
def pull(self, timeout_ms: Optional[int] = None):
"""
Pull and decode JSON data with configurable timeout.
@ -126,3 +126,52 @@ class ZMQJsonPuller:
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
def grouping(num_senders, num_receivers):
groups = {}
assert num_senders >= num_receivers
# Each PULL gets multiple PUSH
senders_per_receiver = num_senders // num_receivers
for receiver_id in range(num_receivers):
start = receiver_id * senders_per_receiver
end = (receiver_id + 1) * senders_per_receiver
groups[receiver_id] = list(range(start, end))
# Distribute remaining senders
remaining = num_senders % num_receivers
for i in range(remaining):
groups[i].append(num_receivers * senders_per_receiver + i)
return groups
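For illustration, grouping assigns each receiver a contiguous block of senders and spreads any remainder over the first receivers; a small worked example (numbers chosen arbitrarily):

# grouping(num_senders=5, num_receivers=2) -> {0: [0, 1, 4], 1: [2, 3]}
# Receivers 0 and 1 each get 5 // 2 = 2 senders; the single leftover sender
# (index 4) is appended to receiver 0.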
class NameResolvingZmqPusher(ZMQJsonPusher):
def __init__(self, experiment_name, trial_name, pusher_index, pusher_cnt, **kwargs):
pullers = name_resolve.get_subtree(
names.stream_pullers(experiment_name, trial_name)
)
pullers = list(map(int, pullers))
puller_cnt = len(pullers)
assert sorted(pullers) == list(range(puller_cnt))
groups = grouping(pusher_cnt, puller_cnt)
puller_index = None
for puller_index, pusher_indices in groups.items():
if pusher_index in pusher_indices:
break
assert puller_index is not None
name = names.push_pull_stream(
experiment_name, trial_name, stream_name=f"puller{puller_index}"
)
addr = name_resolve.wait(name)
host, port = addr.split(":")
super().__init__(host, int(port), **kwargs)
class NameResolvingZmqPuller(ZMQJsonPuller):
def __init__(self, experiment_name, trial_name, puller_index, **kwargs):
name = names.push_pull_stream(
experiment_name, trial_name, stream_name=f"puller{puller_index}"
)
host, port = network.gethostip(), network.find_free_port()
addr = f"{host}:{port}"
name_resolve.add(name, addr)
super().__init__(host, port, **kwargs)
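A hedged sketch of the intended wiring (the experiment/trial names and indices are placeholders): the dataset-owning model worker advertises its data-parallel rank under names.stream_pullers and creates the puller, which registers its address; each rollout worker then creates a pusher that resolves its target puller via grouping().

# Illustrative only; "exp" / "trial" and the indices are assumptions.
puller = NameResolvingZmqPuller("exp", "trial", puller_index=0)
pusher = NameResolvingZmqPusher("exp", "trial", pusher_index=3, pusher_cnt=8)
pusher.push([traj.as_json_compatible() for traj in trajs])  # trajs: List[SequenceSample]
batch = puller.pull(timeout_ms=100)  # JSON-compatible list; PullerStreamDataset converts it back to SequenceSample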

View File

@ -0,0 +1,332 @@
import asyncio
import json
import os
import queue
import time
from asyncio.queues import QueueEmpty
from typing import Dict, Hashable, List
import aiohttp
import numpy as np
import torch.utils.data
from aiohttp.client import ClientTimeout
from realhf.api.core.agent_api import make_agent
from realhf.api.core.data_api import SequenceSample, load_hf_tokenizer, make_dataset
from realhf.api.core.env_api import make_env
from realhf.api.core.system_api import RolloutWorker as RolloutWorkerConfig
from realhf.base import (
constants,
datapack,
logging,
name_resolve,
names,
recover,
seeding,
)
from realhf.system.partial_rollout import PartialRolloutManager
from realhf.system.push_pull_stream import NameResolvingZmqPusher
from realhf.system.worker_base import AsyncWorker, PollResult
# NOTE: Register all implemented agents
import realhf.impl.environment # isort:skip
import realhf.impl.agent # isort:skip
logger = logging.getLogger("RolloutWorker")
# Should be equal to the poll time of partial rollout
ROLLOUT_POLL_WAIT_TIME = 0.4
class RolloutWorker(AsyncWorker):
def _configure(self, config: RolloutWorkerConfig):
self.model_name = config.model_name
self.config = config
self.worker_index = config.worker_info.worker_index
self.worker_count = config.worker_info.worker_count
self.experiment_name = config.worker_info.experiment_name
self.trial_name = config.worker_info.trial_name
self.env = make_env(config.env)
self.agent = make_agent(config.agent)
self.rollout_request_queue = asyncio.Queue(1024)
self.rollout_response_queue = asyncio.Queue(1024)
self.act_queues = {}
self.rollout_tasks = {}
self.inference_maker = PartialRolloutManager(
worker_index=self.worker_index,
request_queue=self.rollout_request_queue,
reply_queue=self.rollout_response_queue,
new_tokens_per_chunk=config.new_tokens_per_chunk,
tokenizer=load_hf_tokenizer(config.tokenizer_path),
timeout=self.config.rollout_request_timeout,
)
self.push_stream = None
seeding.set_random_seed(
config.base_seed, f"rollout_worker{config.worker_info.worker_index}"
)
self.data_generator = None
self.is_new_epoch = False
self._cur_data = None
self.gserver_manager_addr = None
self.rollout_tasks: Dict[Hashable, asyncio.Task] = {}
# recover info
self.__recover_run, self.__recover_info = recover.load_recover_info()
return config.worker_info
def make_datasets(self):
# Make datasets.
datasets = [
make_dataset(
d,
# NOTE: we must use the same seed to ensure the same dataset split
self.config.base_seed,
self.worker_index,
self.worker_count,
self.config.tokenizer_path,
self.config.worker_info.experiment_name,
self.config.worker_info.trial_name,
cache_root=(
None
if not self.config.use_dataset_cache
else self.config.dataset_cahce_root
),
)
for d in self.config.datasets
]
if len(self.config.datasets) == 1:
self.dataset = datasets[0]
else:
self.dataset = torch.utils.data.ConcatDataset(datasets)
self.dataset_size = len(self.dataset)
g = torch.Generator()
g.manual_seed(seeding.get_seed())
self.dataloader = torch.utils.data.DataLoader(
self.dataset,
batch_size=None,
shuffle=True,
generator=g,
)
self.data_generator = enumerate(self.dataloader)
# Recover indices for dynamic dataset
if hasattr(self.dataset, "filter"):
dataset_indices_path = os.path.join(
constants.MODEL_SAVE_ROOT,
constants.experiment_name(),
constants.trial_name(),
f"dataset_indices_{self.worker_index}.npy",
)
if os.path.exists(dataset_indices_path):
indices = np.load(dataset_indices_path).tolist()
logger.info(
f"DP rank {self.worker_index} updating dataset indices upon recover, "
f"size {len(self.dataset.active_indices)} -> {len(indices)}"
)
self.dataset.active_indices = indices
def load_next_data(self):
# Create an epoch-wise barrier to prevent data over-consumption.
if self.is_new_epoch:
if len(self.rollout_tasks) > 0:
return None
self.is_new_epoch = False
# Fetch.
try:
_, cur_sample = next(self.data_generator)
except StopIteration:
self.is_new_epoch = True
# At the end of an epoch, filter the dataset and re-create the dataloader.
eval_scores_path = os.path.join(
constants.MODEL_SAVE_ROOT,
constants.experiment_name(),
constants.trial_name(),
"dataset_eval_scores.json",
)
dataset_indices_path = os.path.join(
constants.MODEL_SAVE_ROOT,
constants.experiment_name(),
constants.trial_name(),
f"dataset_indices_{self.worker_index}.npy",
)
if hasattr(self.dataset, "filter") and os.path.exists(eval_scores_path):
# Don't filter dataset on the first poll after recover.
with open(eval_scores_path, "r", encoding="utf-8") as f:
dataset_eval_scores = json.load(f)
self.dataset.filter(dataset_eval_scores)
# Save the dataset indices after filtering
np.save(
dataset_indices_path,
self.dataset.active_indices,
)
g = torch.Generator()
g = g.set_state(self.dataloader.generator.get_state())
self.dataloader = torch.utils.data.DataLoader(
self.dataset,
batch_size=None,
shuffle=True,
generator=g,
)
self.data_generator = enumerate(self.dataloader)
return None
data_id = cur_sample.ids[0]
if self.__recover_run and data_id in self.__recover_info.hash_vals_to_ignore:
self.__recover_info.hash_vals_to_ignore.remove(data_id)
return None
assert data_id not in self.rollout_tasks
return cur_sample
async def allocate_new_rollout(self) -> bool:
async with aiohttp.ClientSession() as session:
async with session.get(
f"http://{self.gserver_manager_addr}/allocate_rollout",
timeout=ClientTimeout(
total=self.config.rollout_request_timeout, sock_connect=30
),
) as resp:
resp.raise_for_status()
res = await resp.json()
return res["success"]
async def _poll_async(self):
# Lazily initialize the dataset to avoid an overly long configuration time.
if self.data_generator is None:
tik = time.perf_counter()
logger.info(f"Rollout worker {self.worker_index} making datasets..")
self.make_datasets()
logger.info(
f"Rollout worker {self.worker_index} finishes making datasets. "
f"Time consumed: {time.perf_counter() - tik}s"
)
if self.push_stream is None:
# Initialize stream after configure to ensure that puller names have been written.
self.push_stream = NameResolvingZmqPusher(
self.experiment_name,
self.trial_name,
pusher_index=self.worker_index,
pusher_cnt=self.worker_count,
)
if self.gserver_manager_addr is None:
name = names.gen_server_manager(self.experiment_name, self.trial_name)
self.gserver_manager_addr = name_resolve.wait(name)
# Create new trajectory collection tasks.
# Load only one sample in each poll to avoid over-consumption.
if self._cur_data is None:
self._cur_data = self.load_next_data()
if self._cur_data is not None:
can_rollout = await self.allocate_new_rollout()
if can_rollout:
data = self._cur_data
qid = data.ids[0]
self.act_queues[qid] = asyncio.Queue(1024)
task = asyncio.create_task(self.rollout_task(qid, data))
self.rollout_tasks[qid] = task
self._cur_data = None
# Run rollouts and wait
done, *_ = await asyncio.gather(
self.poll_rollout_task(),
self.poll_queue_dispatch_task(),
self.poll_inference_task(),
)
# Process done tasks.
batch_count = sample_count = 0
for task in done:
qid, trajs = await task
trajs: List[SequenceSample]
assert len(set(traj.ids[0] for traj in trajs)) == len(trajs), [
traj.ids[0] for traj in trajs
]
self.rollout_tasks.pop(qid)
self.act_queues.pop(qid)
accepted = False
if len(trajs) > 0:
accepted = True
self.push_stream.push([traj.as_json_compatible() for traj in trajs])
info = dict(qid=qid, accepted=accepted)
async with aiohttp.ClientSession(
f"http://{self.gserver_manager_addr}"
) as session:
async with session.post(
"/finish_rollout",
json=info,
timeout=ClientTimeout(
total=self.config.rollout_request_timeout, sock_connect=30
),
) as resp:
resp.raise_for_status()
assert (await resp.json())["success"]
for traj in trajs:
batch_count += traj.bs
sample_count += max(
[sum(datapack.flat2d(slens)) for slens in traj.seqlens.values()]
)
return PollResult(batch_count, sample_count)
async def rollout_task(self, qid, data):
return qid, await self.agent.collect_trajectory(
env=self.env,
prompt=data,
act_queue=self.act_queues[qid],
obs_queue=self.rollout_request_queue,
)
async def poll_inference_task(self):
await self.inference_maker.run_step()
async def poll_rollout_task(self):
tasks = list(self.rollout_tasks.values())
done = []
if tasks:
done, _ = await asyncio.wait(
tasks,
timeout=ROLLOUT_POLL_WAIT_TIME,
return_when=asyncio.FIRST_COMPLETED,
)
return done
async def poll_queue_dispatch_task(self):
for _ in range(20):
try:
resp = self.rollout_response_queue.get_nowait()
self.act_queues[resp.qid].put_nowait(resp)
except QueueEmpty:
await asyncio.sleep(0.02)
async def _exit_async_tasks(self):
for task in self.rollout_tasks.values():
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
def _exit_hook(self, exit_status):
if self.push_stream is not None:
self.push_stream.close()
loop = asyncio.get_event_loop()
loop.run_until_complete(self._exit_async_tasks())

View File

@ -0,0 +1,100 @@
import queue
import threading
import time
from typing import Any, List, Optional
from torch.utils.data import ConcatDataset, Dataset
from realhf.api.core.config import DatasetAbstraction
from realhf.api.core.data_api import (
DatasetUtility,
SequenceSample,
make_dataset,
register_dataset,
)
from realhf.base import constants
from realhf.system.push_pull_stream import NameResolvingZmqPuller
class PullerStreamDataset(Dataset):
def __init__(
self,
util: DatasetUtility,
dataset_cfgs: List[DatasetAbstraction],
pull_timeout_ms=100,
):
# The wrapped datasets are instantiated only to compute the dataset size
# (and hence the number of steps per epoch); they are discarded right after.
datasets = [
make_dataset(
dataset_cfg,
seed=util.seed,
dp_rank=util.dp_rank,
world_size=util.world_size,
tokenizer_or_tokenizer_name=util.tokenizer,
experiment_name=constants.experiment_name(),
trial_name=constants.trial_name(),
)
for dataset_cfg in dataset_cfgs
]
if len(datasets) == 1:
dataset = datasets[0]
else:
dataset = ConcatDataset(datasets)
self.dataset_size = len(dataset)
del dataset, datasets
self.pull_timeout_ms = pull_timeout_ms
self.data_queue = queue.Queue(maxsize=self.dataset_size)
self._stop_event = threading.Event()
# Pass ZMQ context (thread-safe) and let worker create the socket
self.util = util
self.worker_thread = threading.Thread(
target=self._pull_data_worker,
daemon=True,
)
self.worker_thread.start()
def _pull_data_worker(self):
"""Worker thread that creates its own ZMQ puller and streams data."""
# Initialize the puller inside the worker thread
stream = NameResolvingZmqPuller(
constants.experiment_name(),
constants.trial_name(),
puller_index=self.util.dp_rank,
)
try:
while not self._stop_event.is_set():
try:
data = stream.pull(timeout_ms=self.pull_timeout_ms)
processed_data = [
SequenceSample.from_json_compatible(x) for x in data
]
self.data_queue.put_nowait(processed_data)
except queue.Empty:
time.sleep(0.1)
continue
finally:
# Ensure socket is closed in the same thread
del stream
def __getitem__(self, idx: int) -> Optional[Any]:
samples = []
while True:
try:
samples += self.data_queue.get_nowait()
except queue.Empty:
break
return samples
def __len__(self) -> int:
return self.dataset_size
def __del__(self):
self._stop_event.set()
if self.worker_thread.is_alive():
self.worker_thread.join(timeout=1.0)
register_dataset("puller_stream", PullerStreamDataset)
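A minimal usage sketch for the model worker side (util and dataset_cfgs stand in for the worker's usual dataset setup and are not defined in this PR): because __getitem__ already returns a list of SequenceSample drained from the ZMQ queue, the DataLoader is built with batch_size=None and no collate_fn, matching the dataloader_kwargs branch in the model worker diff above.

import torch.utils.data

dataset = PullerStreamDataset(util, dataset_cfgs, pull_timeout_ms=100)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=None, shuffle=True)
for samples in dataloader:
    # `samples` is a possibly empty list of SequenceSample pulled from rollout workers.
    ...

len(dataset) still reports the size of the underlying dataset, so the number of steps per epoch can be computed as before.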

56
realhf/version.py Normal file
View File

@ -0,0 +1,56 @@
import subprocess
from pathlib import Path
__version__ = "0.3.0-dev"
__branch__ = ""
__commit__ = ""
__is_dirty__ = False
try:
__branch__ = (
subprocess.check_output(
["git", "branch", "--show-current"],
stderr=subprocess.DEVNULL,
cwd=Path(__file__).parent,
)
.decode("utf-8")
.strip()
)
__commit__ = (
subprocess.check_output(
["git", "rev-parse", "--short", "HEAD"],
stderr=subprocess.DEVNULL,
cwd=Path(__file__).parent,
)
.decode("utf-8")
.strip()
)
__is_dirty__ = False
try:
subprocess.check_call(
["git", "diff-index", "--quiet", "HEAD", "--"],
stderr=subprocess.DEVNULL,
cwd=Path(__file__).parent,
)
except subprocess.CalledProcessError:
__is_dirty__ = True
except (subprocess.CalledProcessError, FileNotFoundError):
pass
def get_full_version() -> str:
version = __version__
if __commit__ != "":
version = f"{__version__}-{__commit__}"
if __is_dirty__:
version = f"{version}-dirty"
return version
def get_full_version_with_dirty_description() -> str:
version = get_full_version()
if __is_dirty__:
version = (
f"{version} ('-dirty' means there are uncommitted code changes in git)"
)
return version
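For illustration (the commit hash is a placeholder): on a checkout at commit ab12cd3 with uncommitted changes, get_full_version() would return "0.3.0-dev-ab12cd3-dirty"; if git metadata cannot be read, the bare "0.3.0-dev" is returned.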

View File

@ -52,9 +52,10 @@ torch>2.0.0
black==25.1.0
cookiecutter>2.1.1
asyncio
aiohttp
aiohttp>=3.11.10
httpx>=0.28.1
etcd3
protobuf<3.21
rich
orjson>=3.10.16
flask

View File

@ -0,0 +1,192 @@
import asyncio
import json
import os
from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import pytest
import torch
from realhf.api.core.agent_api import Agent
from realhf.api.core.model_api import BundledGenerationOutputs
from realhf.base import constants, name_resolve, testing
@pytest.fixture
def mock_env():
env = AsyncMock()
env.reset = AsyncMock()
env.step = AsyncMock(return_value=(None, [0.5, 0.7], None))
return env
@pytest.fixture
def agent_config():
return {
"gconfig": MagicMock(n=2),
"tokenizer_path": "/storage/openpsi/models/Qwen__Qwen2.5-0.5B-Instruct/",
"success_rate_lb": 0.1,
"success_rate_ub": 1.0,
"reward_scaling": 2.0,
"reward_bias": 0.1,
}
@pytest.fixture
def agent(agent_config):
from realhf.impl.agent.math_single_step_agent import MathSingleStepAgent
testing.clear_name_resolve()
constants.set_experiment_trial_names(
testing._DEFAULT_EXPR_NAME, testing._DEFAULT_TRIAL_NAME
)
agent = MathSingleStepAgent(**agent_config)
yield agent
@pytest.fixture
def mock_prompt():
from realhf.api.core import data_api
return data_api.SequenceSample(
ids=[str(123)],
data={"packed_prompts": torch.tensor([1, 2, 3])},
keys=set(["packed_prompts"]),
seqlens=dict(packed_prompts=[[3]]),
dtypes=dict(packed_prompts=torch.long),
trailing_shapes=dict(packed_prompts=()),
)
@pytest.fixture
def mock_act():
return BundledGenerationOutputs(
qid=str(123),
seqs=[[1, 2, 3, 4, 5, 6], [1, 2, 3, 7, 8, 9]],
output_ids=[[4, 5, 6], [7, 8, 9]],
prompt_ids=[1, 2, 3],
logprobs=[[0, 0, -0.1, -0.2, -0.3], [0, 0, -0.3, -0.2, -0.3]],
no_eos=[True, False],
version_start=[1, 1],
version_end=[2, 2],
)
@pytest.mark.asyncio
async def test_collect_trajectory_happy_path(agent, mock_env, mock_prompt, mock_act):
obs_queue = asyncio.Queue()
act_queue = asyncio.Queue()
await act_queue.put(mock_act)
result = await agent.collect_trajectory(mock_prompt, mock_env, obs_queue, act_queue)
assert len(result) == 1
sample = result[0]
assert sample.ids == [str(123)]
assert torch.equal(sample.data["packed_prompts"], torch.tensor([1, 2, 3]))
assert torch.equal(sample.data["rewards"], torch.tensor([0.8, 1.2]))
@pytest.mark.asyncio
async def test_collect_trajectory_low_reward(
agent_config, mock_env, mock_prompt, mock_act
):
# Set reward lower bound higher than what env will return
agent_config["success_rate_lb"] = 1.0
from realhf.impl.agent.math_single_step_agent import MathSingleStepAgent
agent = MathSingleStepAgent(**agent_config)
obs_queue = asyncio.Queue()
act_queue = asyncio.Queue()
await act_queue.put(mock_act)
result = await agent.collect_trajectory(mock_prompt, mock_env, obs_queue, act_queue)
assert len(result) == 0
@pytest.mark.asyncio
async def test_collect_trajectory_high_reward(
agent_config, mock_env, mock_prompt, mock_act
):
# Set reward upper bound lower than what env will return
agent_config["success_rate_ub"] = 0.0
from realhf.impl.agent.math_single_step_agent import MathSingleStepAgent
agent = MathSingleStepAgent(**agent_config)
obs_queue = asyncio.Queue()
act_queue = asyncio.Queue()
await act_queue.put(mock_act)
result = await agent.collect_trajectory(mock_prompt, mock_env, obs_queue, act_queue)
assert len(result) == 0
@pytest.mark.asyncio
async def test_collect_trajectory_empty_act_queue(agent, mock_env, mock_prompt):
obs_queue = asyncio.Queue()
act_queue = asyncio.Queue()
with pytest.raises(asyncio.TimeoutError):
await asyncio.wait_for(
agent.collect_trajectory(mock_prompt, mock_env, obs_queue, act_queue),
timeout=1,
)
def test_log_rewards_to_file(agent, tmp_path):
# Setup test directories
with (
patch("realhf.base.constants.LOG_ROOT", tmp_path),
patch("realhf.base.constants.experiment_name", return_value="test_exp"),
patch("realhf.base.constants.trial_name", return_value="test_trial"),
):
agent.log_rewards_to_file(
qid="123",
prompt="test_prompt",
prompt_len=3,
answers=["answer1", "answer2"],
seqlens=[5, 6],
rewards=[0.5, 0.7],
success=[True, False],
version_starts=[1, 2],
version_ends=[2, 3],
)
# Check generated file
gen_file_path = (
tmp_path / "test_exp" / "test_trial" / "generated" / "1" / "123.txt"
)
assert gen_file_path.exists()
with open(gen_file_path) as f:
content = f.read()
assert "idx: 1 / 2" in content
assert "seqlen: 5" in content
assert "test_prompt" in content
# Check monitor file
monitor_file_path = (
tmp_path
/ "test_exp"
/ "test_trial"
/ "training_monitor"
/ "1"
/ "123.jsonl"
)
assert monitor_file_path.exists()
with open(monitor_file_path) as f:
data = json.loads(f.readline())
assert data["version_start"] == 1
assert data["prompt_len"] == 3
def test_reward_calculation(agent):
# Test reward scaling and biasing
raw_rewards = [0.2, 0.4]
expected = [(0.2 - 0.1) * 2.0, (0.4 - 0.1) * 2.0]
processed = [
(float(r) - agent.reward_bias) * agent.reward_scaling for r in raw_rewards
]
assert processed == expected

View File

@ -0,0 +1,63 @@
import pytest
import torch
from realhf.impl.model.utils.ppo_functional import actor_loss_fn
# Copied from https://github.com/opendilab/PPOxFamily/blob/main/chapter7_tricks/dual_clip.py
def ppo_dual_clip(
logp_new: torch.FloatTensor,
logp_old: torch.FloatTensor,
adv: torch.FloatTensor,
clip_ratio: float,
dual_clip: float,
) -> torch.FloatTensor:
"""
**Overview**:
This function implements the Proximal Policy Optimization (PPO) policy loss with dual-clip
mechanism, which is a variant of PPO that provides more reliable and stable training by
limiting the updates to the policy, preventing it from deviating too much from its previous versions.
Arguments:
- logp_new (:obj:`torch.FloatTensor`): The log probability calculated by the new policy.
- logp_old (:obj:`torch.FloatTensor`): The log probability calculated by the old policy.
- adv (:obj:`torch.FloatTensor`): The advantage value, which measures how much better an
action is compared to the average action at that state.
- clip_ratio (:obj:`float`): The clipping ratio used to limit the change of policy during an update.
- dual_clip (:obj:`float`): The dual clipping ratio used to further limit the change of policy during an update.
Returns:
- policy_loss (:obj:`torch.FloatTensor`): The calculated policy loss, which is the objective we
want to minimize for improving the policy.
"""
assert (
dual_clip is None or dual_clip > 1.0
), "Dual_clip value must be greater than 1.0, but get value: {}".format(dual_clip)
# This is the ratio of the new policy probability to the old policy probability.
# $$r(\theta) = \frac{\pi_{new}(a|s)}{\pi_{old}(a|s)}$$
ratio = torch.exp(logp_new - logp_old)
# The first clipping operation is performed here, we limit the update to be within a certain range.
# $$clip_1 = min(r(\theta)*A(s,a), clip(r(\theta), 1-clip\_ratio, 1+clip\_ratio)*A(s,a))$$
surr1 = ratio * adv
surr2 = ratio.clamp(1 - clip_ratio, 1 + clip_ratio) * adv
clip1 = torch.min(surr1, surr2)
# The second clipping operation is performed here, we further limit the update to be within a stricter range.
# $$clip_2 = max(clip_1, dual\_clip * A(s,a))$$
if dual_clip is not None:
clip2 = torch.max(clip1, dual_clip * adv)
# We only apply the dual-clip when the advantage is negative, i.e., when the action is worse than the average.
policy_loss = -(torch.where(adv < 0, clip2, clip1)).mean()
else:
policy_loss = -clip1.mean()
return policy_loss
@pytest.mark.parametrize("eps_clip", [0.01, 0.05, 0.1, 0.2, 0.5])
@pytest.mark.parametrize("c_clip", [None, 1.5, 2.0, 5.0])
@pytest.mark.parametrize("size", [(1,), (10,), (100,)])
def test_dual_clip_acc(size, eps_clip, c_clip):
old_logp = -torch.randn(size, dtype=torch.float32).abs()
new_logp = -torch.randn(size, dtype=torch.float32).abs()
adv = torch.randn(size, dtype=torch.float32)
loss1 = ppo_dual_clip(new_logp, old_logp, adv, eps_clip, c_clip)
loss2, _ = actor_loss_fn(new_logp, old_logp, adv, eps_clip, c_clip=c_clip)
assert torch.allclose(loss1, loss2)

View File

@ -249,6 +249,13 @@ def test_gather_split(sample_type: str, dp: int):
for s1, s2 in zip(samples, ss):
recursive_assert_equal(s1, s2)
# Test json serialize
import orjson
bytes = orjson.dumps(x.as_json_compatible())
y = SequenceSample.from_json_compatible(orjson.loads(bytes))
recursive_assert_equal(x, y)
# Test split to the finest granularity
total_bs = sum(batch_sizes)
ss, _, backward_indices = x.split(MicroBatchSpec(n_mbs=x.bs))

View File

@ -0,0 +1,235 @@
import pytest
import torch
from realhf.base.stats_tracker import DistributedStatsTracker, ReduceType
@pytest.fixture
def tracker():
return DistributedStatsTracker()
def test_basic_stat_recording(tracker):
# Test basic stat recording and averaging
mask = torch.BoolTensor([True, False, True])
values = torch.FloatTensor([1.0, 2.0, 3.0])
tracker.denominator(mask=mask)
tracker.stat(denominator="mask", value=values)
results = tracker.export()
assert pytest.approx(results["value/avg"]) == 2.0 # (1+3)/2
assert pytest.approx(results["value/min"]) == 1.0
assert pytest.approx(results["value/max"]) == 3.0
def test_scoping(tracker):
# Test hierarchical scoping
with tracker.scope("parent"):
tracker.denominator(parent_mask=torch.BoolTensor([True]))
tracker.stat(parent_value=torch.FloatTensor([1.0]), denominator="parent_mask")
with tracker.scope("child"):
with pytest.raises(ValueError):
tracker.stat(denominator="child_mask", value=torch.FloatTensor([1.0]))
tracker.denominator(child_mask=torch.BoolTensor([1.0]))
tracker.stat(denominator="child_mask", value=torch.FloatTensor([1.0]))
results = tracker.export()
assert "parent/parent_mask" in results
assert "parent/parent_value/avg" in results
assert "parent/child/child_mask" in results
assert "parent/child/value/avg" in results
def test_reduce_types(tracker):
# Test different reduce types
mask = torch.BoolTensor([True, False, True])
values = torch.FloatTensor([1.0, 2.0, 3.0])
tracker.denominator(mask=mask)
tracker.stat(denominator="mask", reduce_type=ReduceType.SUM, sum_val=values)
tracker.stat(denominator="mask", reduce_type=ReduceType.MIN, min_val=values)
tracker.stat(denominator="mask", reduce_type=ReduceType.MAX, max_val=values)
results = tracker.export()
assert pytest.approx(results["sum_val"]) == 4.0 # 1+3
assert pytest.approx(results["min_val"]) == 1.0
assert pytest.approx(results["max_val"]) == 3.0
def test_validation_checks(tracker):
# Test input validation
with pytest.raises(ValueError):
tracker.denominator(invalid=torch.FloatTensor([1.0])) # Not bool tensor
tracker.denominator(mask=torch.BoolTensor([True]))
with pytest.raises(ValueError):
tracker.stat(denominator="nonexistent", value=torch.FloatTensor([1.0]))
with pytest.raises(AssertionError):
tracker.stat(
denominator="mask", value=torch.FloatTensor([1.0, 2.0]) # Shape mismatch
)
def test_multiple_recordings(tracker):
# Test multiple recordings
mask1 = torch.BoolTensor([True, False])
mask2 = torch.BoolTensor([False, True])
values1 = torch.FloatTensor([1.0, 2.0])
values2 = torch.FloatTensor([3.0, 4.0])
tracker.denominator(mask=mask1)
tracker.denominator(mask=mask2)
tracker.stat(denominator="mask", value=values1)
tracker.stat(denominator="mask", value=values2)
results = tracker.export()
assert (
pytest.approx(results["value/avg"]) == (1.0 + 4.0) / 2
) # (1 from 1st, 4 from 2nd)
def test_denominator_edge_cases(tracker):
# Test edge cases with denominators
with pytest.raises(ValueError): # Should fail on shape check
empty_mask = torch.BoolTensor([])
tracker.denominator(mask=empty_mask)
zero_mask = torch.BoolTensor([False, False])
tracker.denominator(mask=zero_mask)
tracker.stat(denominator="mask", value=torch.FloatTensor([1.0, 2.0]))
results = tracker.export()
assert torch.isnan(torch.tensor(results["value/min"])) # NaN when no entry is valid
assert torch.isnan(torch.tensor(results["value/max"])) # NaN when no entry is valid
assert results["value/avg"] == 0.0
def test_key_specific_export(tracker):
# Test exporting specific keys
tracker.denominator(mask1=torch.BoolTensor([True]), mask2=torch.BoolTensor([True]))
tracker.stat(denominator="mask1", value1=torch.FloatTensor([1.0]))
tracker.stat(denominator="mask2", value2=torch.FloatTensor([2.0]))
result = tracker.export(key="value1")
assert "value1/avg" in result
assert "value2/avg" not in result
def test_scalar_values(tracker):
# Test scalar value recording and averaging
tracker.scalar(scalar1=1.0, scalar2=2.0)
tracker.scalar(scalar1=3.0, scalar2=4.0)
results = tracker.export()
assert pytest.approx(results["scalar1"]) == 2.0 # (1+3)/2
assert pytest.approx(results["scalar2"]) == 3.0 # (2+4)/2
def test_moe_aux_losses(monkeypatch, tracker):
# Test MOE auxiliary losses handling
from realhf.base.stats_tracker import MOE_AUX_LOSSES
# Mock distributed environment
monkeypatch.setattr("torch.distributed.is_initialized", lambda: True)
monkeypatch.setattr("torch.distributed.all_reduce", lambda x, group: x)
# Mock pipe parallel group and last stage check
mock_group = object()
monkeypatch.setattr("realhf.base.constants.pipe_parallel_group", lambda: mock_group)
monkeypatch.setattr("realhf.base.constants.is_last_pipe_stage", lambda: True)
# Set up test MOE losses
MOE_AUX_LOSSES["moe_loss1"] = torch.tensor([1.0, 2.0])
MOE_AUX_LOSSES["moe_loss2"] = torch.tensor([3.0, 4.0])
results = tracker.export()
assert pytest.approx(results["moe_loss1"]) == 1.5 # (1+2)/2
assert pytest.approx(results["moe_loss2"]) == 3.5 # (3+4)/2
assert not MOE_AUX_LOSSES # Should be cleared after export
def test_empty_tracker(tracker):
# Test exporting from an empty tracker
results = tracker.export()
assert results == {}
def test_reset_behavior(tracker):
# Test that stats are reset after export
tracker.denominator(mask=torch.BoolTensor([True]))
tracker.stat(denominator="mask", value=torch.FloatTensor([1.0]))
results1 = tracker.export()
assert "value/avg" in results1
results2 = tracker.export()
assert results2 == {}
def test_no_reset_behavior(tracker):
# Test that stats are preserved when reset=False
tracker.denominator(mask=torch.BoolTensor([True]))
tracker.stat(denominator="mask", value=torch.FloatTensor([1.0]))
results1 = tracker.export(reset=False)
assert "value/avg" in results1
results2 = tracker.export()
assert "value/avg" in results2 # Should still be there
def test_default_tracker():
# Test the default tracker instance and its functions
mask = torch.BoolTensor([True, False, True])
values = torch.FloatTensor([1.0, 2.0, 3.0])
from realhf.base.stats_tracker import denominator, export, stat
denominator(mask=mask)
stat(denominator="mask", value=values)
results = export()
assert pytest.approx(results["value/avg"]) == 2.0
def test_reduce_type_validation(tracker):
# Test invalid reduce type handling
with pytest.raises(ValueError):
tracker._set_reduce_type("key", "invalid_type") # Not a ReduceType enum
with pytest.raises(ValueError):
tracker.denominator(mask=torch.BoolTensor([True]))
tracker.stat(
denominator="mask", reduce_type="invalid", value=torch.FloatTensor([1.0])
)
def test_scalar_reduce_type_validation(tracker):
# Test that SCALAR reduce type can't be used with tensors
tracker.denominator(mask=torch.BoolTensor([True]))
with pytest.raises(ValueError):
tracker.stat(
denominator="mask",
reduce_type=ReduceType.SCALAR,
value=torch.FloatTensor([1.0]),
)
def test_full_key_generation(tracker):
# Test full key generation with and without scope
assert tracker._get_full_key("key") == "key"
with tracker.scope("scope1"):
assert tracker._get_full_key("key") == "scope1/key"
with tracker.scope("scope2"):
assert tracker._get_full_key("key") == "scope1/scope2/key"
# Test with empty name in constructor
empty_tracker = DistributedStatsTracker(name="")
assert empty_tracker._get_full_key("key") == "key"
# Test with name in constructor
named_tracker = DistributedStatsTracker(name="root")
assert named_tracker._get_full_key("key") == "root/key"

View File

@ -1,151 +0,0 @@
import os
import time
import etcd3
import pytest
from realhf.base.name_resolve import (
Etcd3NameRecordRepository,
NameEntryExistsError,
NameEntryNotFoundError,
)
host, port = os.getenv("REAL_ETCD_ADDR", "localhost:2379").split(":")
port = int(port)
@pytest.fixture
def etcd_client():
client = etcd3.client(host=host, port=port)
yield client
# Clean up etcd after each test
client.delete_prefix("test_") # Delete all keys
# Fixture to provide an instance of Etcd3NameRecordRepository
@pytest.fixture
def etcd_repo():
repo = Etcd3NameRecordRepository(host=host, port=port)
yield repo
repo.reset() # Clean up repository after each test
def test_add(etcd_repo):
# Test adding a new key-value pair
etcd_repo.add("test_key", "test_value")
value, _ = etcd_repo._client.get("test_key")
assert value.decode("utf-8") == "test_value"
# Test adding a key that already exists without replace
with pytest.raises(NameEntryExistsError):
etcd_repo.add("test_key", "new_value", replace=False)
# Test adding a key that already exists with replace
etcd_repo.add("test_key", "new_value", replace=True)
value, _ = etcd_repo._client.get("test_key")
assert value.decode("utf-8") == "new_value"
def test_delete(etcd_repo):
# Test deleting an existing key
etcd_repo.add("test_key", "test_value")
etcd_repo.delete("test_key")
value, _ = etcd_repo._client.get("test_key")
assert value is None
# Test deleting a non-existent key
with pytest.raises(NameEntryNotFoundError):
etcd_repo.delete("non_existent_key")
def test_clear_subtree(etcd_repo):
# Test clearing a subtree
etcd_repo.add("test_key/sub1", "value1")
etcd_repo.add("test_key/sub2", "value2")
etcd_repo.clear_subtree("test_key")
value1, _ = etcd_repo._client.get("test_key/sub1")
value2, _ = etcd_repo._client.get("test_key/sub2")
assert value1 is None
assert value2 is None
def test_get(etcd_repo):
# Test getting an existing key
etcd_repo.add("test_key", "test_value")
assert etcd_repo.get("test_key") == "test_value"
# Test getting a non-existent key
with pytest.raises(NameEntryNotFoundError):
etcd_repo.get("non_existent_key")
def test_get_subtree(etcd_repo):
# Test getting values from a subtree
etcd_repo.add("test_key/sub1", "value1")
etcd_repo.add("test_key/sub2", "value2")
assert etcd_repo.get_subtree("test_key") == ["value1", "value2"]
def test_find_subtree(etcd_repo):
# Test finding keys in a subtree
etcd_repo.add("test_key/sub1", "value1")
etcd_repo.add("test_key/sub2", "value2")
assert etcd_repo.find_subtree("test_key") == ["test_key/sub1", "test_key/sub2"]
def test_reset(etcd_repo):
# Test resetting the repository
etcd_repo.add("test_key1", "value1", delete_on_exit=True)
etcd_repo.add("test_key2", "value2", delete_on_exit=True)
etcd_repo.reset()
value1, _ = etcd_repo._client.get("test_key1")
value2, _ = etcd_repo._client.get("test_key2")
assert value1 is None
assert value2 is None
def test_watch_names(etcd_repo):
# Test watching keys
callback_called = False
def callback():
nonlocal callback_called
callback_called = True
etcd_repo.add("test_key", "test_value")
etcd_repo.watch_names(["test_key"], callback)
# Delete the key to trigger the callback
etcd_repo.delete("test_key")
time.sleep(1) # Give the watcher time to trigger
assert callback_called
def test_keepalive_thread(etcd_repo):
# Test the keepalive thread
etcd_repo.add("test_key", "test_value", keepalive_ttl=2)
time.sleep(1) # Wait for the keepalive thread to refresh the lease
# Ensure the key still exists
value, _ = etcd_repo._client.get("test_key")
assert value.decode("utf-8") == "test_value"
time.sleep(2) # Wait for the lease to expire
with pytest.raises(NameEntryNotFoundError):
etcd_repo.get("test_key")
def test_context_manager(etcd_repo):
# Test the context manager
with etcd_repo as repo:
repo.add("test_key", "test_value", delete_on_exit=True)
assert repo.get("test_key") == "test_value"
# Ensure the key is deleted after exiting the context
value, _ = etcd_repo._client.get("test_key")
assert value is None
def test_del(etcd_repo, etcd_client):
# Test the destructor
etcd_repo.add("test_key", "test_value", delete_on_exit=True)
etcd_repo.__del__()
value, _ = etcd_client.get("test_key")
assert value is None

View File

@ -0,0 +1,681 @@
import os
import shutil
import tempfile
import threading
import time
import uuid
from unittest.mock import MagicMock, patch
import pytest
from realhf.base.name_resolve import (
Etcd3NameRecordRepository,
NameEntryExistsError,
NameEntryNotFoundError,
NfsNameRecordRepository,
)
# Define backend configurations for parameterized tests
BACKENDS = [
("memory", {}),
("nfs", {}),
]
if os.environ.get("REAL_ETCD_ADDR"):
BACKENDS.append(
(
"etcd3",
{
"host": os.getenv("REAL_ETCD_ADDR").split(":")[0],
"port": int(os.getenv("REAL_ETCD_ADDR").split(":")[1]),
},
)
)
@pytest.fixture(params=BACKENDS, ids=[b[0] for b in BACKENDS])
def name_resolve(request):
"""Fixture that provides a name resolve repository for each backend type."""
backend_type, kwargs = request.param
# Special handling for NFS backend to use temp directory
if backend_type == "nfs":
temp_dir = tempfile.mkdtemp()
from realhf.base.name_resolve import NfsNameRecordRepository
original_root = NfsNameRecordRepository.RECORD_ROOT
NfsNameRecordRepository.RECORD_ROOT = temp_dir
repo = NfsNameRecordRepository()
yield repo
repo.reset()
NfsNameRecordRepository.RECORD_ROOT = original_root
shutil.rmtree(temp_dir)
elif backend_type == "memory":
from realhf.base.name_resolve import MemoryNameRecordRepository
repo = MemoryNameRecordRepository()
yield repo
repo.reset()
elif backend_type == "etcd3":
from realhf.base.name_resolve import Etcd3NameRecordRepository
repo = Etcd3NameRecordRepository(**kwargs)
yield repo
repo.reset()
def test_basic_add_get(name_resolve):
"""Test basic add and get functionality."""
# Add a new entry
name_resolve.add("test_key", "test_value")
assert name_resolve.get("test_key") == "test_value"
# Test with non-string value (should be converted to string)
name_resolve.add("test_key_int", 123, replace=True)
assert name_resolve.get("test_key_int") == "123"
def test_add_with_replace(name_resolve):
"""Test add operation with replace flag."""
name_resolve.add("test_key", "initial_value")
# Should fail when replace=False
with pytest.raises(NameEntryExistsError):
name_resolve.add("test_key", "new_value", replace=False)
# Should succeed when replace=True
name_resolve.add("test_key", "new_value", replace=True)
assert name_resolve.get("test_key") == "new_value"
def test_delete(name_resolve):
"""Test delete operation."""
name_resolve.add("test_key", "test_value")
name_resolve.delete("test_key")
# Verify deletion
with pytest.raises(NameEntryNotFoundError):
name_resolve.get("test_key")
# Deleting non-existent key should raise
with pytest.raises(NameEntryNotFoundError):
name_resolve.delete("non_existent_key")
def test_clear_subtree(name_resolve):
"""Test clearing a subtree of keys."""
# Create a subtree of keys
name_resolve.add("test_root/key1", "value1")
name_resolve.add("test_root/key2", "value2")
name_resolve.add("test_root/sub/key3", "value3")
name_resolve.add("other_root/key", "value")
# Clear the subtree
name_resolve.clear_subtree("test_root")
# Verify subtree is gone
assert name_resolve.get_subtree("test_root") == []
assert name_resolve.find_subtree("test_root") == []
# Verify other tree remains
assert name_resolve.get("other_root/key") == "value"
def test_get_subtree(name_resolve):
"""Test retrieving values from a subtree."""
name_resolve.add("test_root/key1", "value1")
name_resolve.add("test_root/key2", "value2")
name_resolve.add("test_root/sub/key3", "value3")
values = name_resolve.get_subtree("test_root")
assert set(values) == {"value1", "value2", "value3"}
def test_find_subtree(name_resolve):
"""Test finding keys in a subtree."""
name_resolve.add("test_root/key1", "value1")
name_resolve.add("test_root/key2", "value2")
name_resolve.add("test_root/sub/key3", "value3")
keys = name_resolve.find_subtree("test_root")
assert set(keys) == {"test_root/key1", "test_root/key2", "test_root/sub/key3"}
assert keys == sorted(keys) # Should be sorted
def test_add_subentry(name_resolve):
"""Test adding subentries with automatic UUID generation."""
sub_name = name_resolve.add_subentry("test_root", "sub_value")
assert sub_name.startswith("test_root/")
assert len(sub_name.split("/")[-1]) == 8 # UUID part should be 8 chars
assert name_resolve.get(sub_name) == "sub_value"
def test_wait(name_resolve):
"""Test waiting for a key to appear."""
def delayed_add():
time.sleep(0.1)
name_resolve.add("test_key", "test_value")
thread = threading.Thread(target=delayed_add, daemon=True)
thread.start()
# Should return once key is added
assert name_resolve.wait("test_key", timeout=2) == "test_value"
thread.join()
# Test timeout
with pytest.raises(TimeoutError):
name_resolve.wait("non_existent_key", timeout=0.1)
def test_watch_names(name_resolve):
"""Test watching keys for changes."""
callback_called = False
def callback():
nonlocal callback_called
callback_called = True
name_resolve.add("test_key", "test_value")
name_resolve.watch_names("test_key", callback, poll_frequency=0.1)
# Delete the key to trigger callback
time.sleep(0.1) # Ensure watcher is ready
name_resolve.delete("test_key")
# Wait for callback
time.sleep(2)
assert callback_called
def test_reset(name_resolve):
"""Test reset functionality (cleanup of delete_on_exit keys)."""
name_resolve.add("test_key1", "value1", delete_on_exit=True)
name_resolve.add("test_key_no_delete", "value2", delete_on_exit=False)
name_resolve.reset()
# Only delete_on_exit=True keys should be removed
with pytest.raises(NameEntryNotFoundError):
name_resolve.get("test_key1")
assert name_resolve.get("test_key_no_delete") == "value2"
name_resolve.delete("test_key_no_delete")
def test_context_manager(name_resolve):
"""Test context manager functionality."""
with name_resolve.__class__() as repo:
repo.add("test_key", "test_value", delete_on_exit=True)
assert repo.get("test_key") == "test_value"
# Key should be deleted after context exits
with pytest.raises(NameEntryNotFoundError):
name_resolve.get("test_key")
def test_concurrent_access(name_resolve):
"""Test concurrent access to the same key."""
name_resolve.add("test_key", "initial_value")
def modify_value():
for i in range(5):
current = name_resolve.get("test_key")
name_resolve.add(
"test_key", f"modified_{threading.get_ident()}_{i}", replace=True
)
time.sleep(0.01)
threads = [threading.Thread(target=modify_value) for _ in range(3)]
for t in threads:
t.start()
for t in threads:
t.join()
# Final value should be one of the modified values
final_value = name_resolve.get("test_key")
assert "modified_" in final_value
def test_path_normalization(name_resolve):
"""Test handling of different path formats."""
# Test paths with trailing slashes
name_resolve.add("test_path/", "value1")
assert name_resolve.get("test_path") == "value1"
# A trailing slash is normalized away, so the slashed form resolves to the same entry
assert name_resolve.get("test_path/") == "value1"
# Test paths with double slashes
name_resolve.add("test//path", "value2")
assert name_resolve.get("test//path") == "value2"
# Test relative paths
with pytest.raises(NameEntryExistsError):
name_resolve.add("./test_path", "value3")
name_resolve.add("./test_path", "value3", replace=True)
assert name_resolve.get("./test_path") == "value3"
def test_add_with_invalid_inputs(name_resolve):
"""Test add method with invalid inputs."""
# Test with None name
with pytest.raises(
Exception
): # The specific exception type may vary by implementation
name_resolve.add(None, "value")
# Test with empty name
with pytest.raises(Exception):
name_resolve.add("", "value")
# Test with None value
name_resolve.add("test_key", None)
assert name_resolve.get("test_key") == "None"
def test_long_paths_and_values(name_resolve):
"""Test behavior with very long path names and values."""
long_name = "a/" * 100 + "key"
long_value = "x" * 10000
name_resolve.add(long_name, long_value)
assert name_resolve.get(long_name) == long_value
def test_special_characters(name_resolve):
"""Test handling of special characters in names and values."""
special_chars = "!@#$%^&*()_+-=[]{}|;:'\",.<>?`~ "
# Test special characters in name
for char in special_chars:
try:
name = f"test{char}key"
name_resolve.add(name, "value")
assert name_resolve.get(name) == "value"
name_resolve.delete(name)
except Exception as e:
print(f"Failed with character '{char}': {e}")
# Test special characters in value
for char in special_chars:
value = f"test{char}value"
name_resolve.add(f"key_{char}", value)
assert name_resolve.get(f"key_{char}") == value
def test_unicode_support(name_resolve):
"""Test support for Unicode characters in names and values."""
unicode_name = "测试/键"
unicode_value = "价值"
name_resolve.add(unicode_name, unicode_value)
assert name_resolve.get(unicode_name) == unicode_value
def test_stress_concurrent_add_get_delete(name_resolve):
"""Stress test with many concurrent operations."""
from concurrent.futures import ThreadPoolExecutor
num_threads = 20
ops_per_thread = 50
# Track success/failure counts
results = {
"success": 0,
"failures": 0,
}
def worker(thread_id):
try:
for i in range(ops_per_thread):
key = f"concurrent_key_{thread_id}_{i}"
value = f"value_{thread_id}_{i}"
# Add the key
name_resolve.add(key, value)
# Get and verify the key
retrieved = name_resolve.get(key)
assert retrieved == value
# Delete the key
name_resolve.delete(key)
# Verify deletion
try:
name_resolve.get(key)
results["failures"] += 1
except NameEntryNotFoundError:
results["success"] += 1
except Exception as e:
print(f"Thread {thread_id} failed: {e}")
results["failures"] += 1
# Run worker threads
with ThreadPoolExecutor(max_workers=num_threads) as executor:
futures = [executor.submit(worker, i) for i in range(num_threads)]
for future in futures:
future.result()
# Verify most operations succeeded
assert (
results["failures"] <= results["success"] * 0.1
) # Allow up to 10% failure rate
def test_add_subentry_uniqueness(name_resolve):
"""Test that add_subentry generates unique names."""
# Add multiple subentries to the same root
num_entries = 100
entries = set()
for _ in range(num_entries):
sub_name = name_resolve.add_subentry("test_root", "value")
entries.add(sub_name)
# Verify all entries are unique
assert len(entries) == num_entries
def test_wait_with_concurrent_delete(name_resolve):
"""Test wait behavior when a key is added and then deleted before wait completes."""
def add_then_delete():
time.sleep(0.1)
name_resolve.add("test_wait_key", "test_value")
time.sleep(0.1)
name_resolve.delete("test_wait_key")
thread = threading.Thread(target=add_then_delete, daemon=True)
thread.start()
# Wait with a timeout long enough to capture the key
value = name_resolve.wait("test_wait_key", timeout=2.0, poll_frequency=0.05)
assert value == "test_value"
# Wait for the thread to complete
thread.join()
# Verify the key was deleted
with pytest.raises(NameEntryNotFoundError):
name_resolve.get("test_wait_key")
def test_wait_edge_cases(name_resolve):
"""Test edge cases for the wait method."""
# Test with invalid timeout values
with pytest.raises(TimeoutError):
name_resolve.wait("nonexistent_key", timeout=0)
# Test with negative timeout (should time out immediately rather than wait forever)
with pytest.raises(TimeoutError):
name_resolve.wait("nonexistent_key", timeout=-1, poll_frequency=0.01)
# Test with very small poll frequency
with pytest.raises(TimeoutError):
name_resolve.wait("nonexistent_key", timeout=0.1, poll_frequency=0.001)
def test_watch_names_multiple_keys(name_resolve):
"""Test watching multiple keys."""
callback_count = 0
def callback():
nonlocal callback_count
callback_count += 1
# Add test keys
name_resolve.add("watch_key1", "value1")
name_resolve.add("watch_key2", "value2")
# Watch both keys
name_resolve.watch_names(["watch_key1", "watch_key2"], callback, poll_frequency=0.1)
# Delete one key
time.sleep(0.2) # Ensure watcher is ready
name_resolve.delete("watch_key1")
# Wait for callback
time.sleep(0.5)
# Delete second key
name_resolve.delete("watch_key2")
# Wait for callback
time.sleep(1.0)
# Callback should have been called exactly once (when the last key is deleted)
assert callback_count == 1
def test_thread_safety_of_watch_thread_run(name_resolve):
"""Test thread safety of _watch_thread_run."""
# Mock the get method to simulate race conditions
original_get = name_resolve.get
def mock_get(name):
# Odd-numbered calls return normally; even-numbered calls raise, simulating a race with deletion
mock_get.counter += 1
if mock_get.counter % 2 == 0:
raise NameEntryNotFoundError(f"Key not found: {name}")
return original_get(name)
mock_get.counter = 0
# Create a callback function that tracks calls
callback_called = False
def callback():
nonlocal callback_called
callback_called = True
# Add a test key
name_resolve.add("test_key", "test_value")
# Patch the get method
with patch.object(name_resolve, "get", side_effect=mock_get):
# Call _watch_thread_run directly
name_resolve._watch_thread_run("test_key", callback, 0.1, 1)
# Verify callback was called
assert callback_called
def test_keepalive_ttl(name_resolve):
"""Test keepalive_ttl functionality."""
# Skip if not Etcd3NameRecordRepository, as TTL might only be fully supported there
if "Etcd3NameRecordRepository" not in name_resolve.__class__.__name__:
pytest.skip("keepalive_ttl test is specific to Etcd3NameRecordRepository")
# Add a key with short TTL
name_resolve.add("ttl_key", "ttl_value", keepalive_ttl=2)
# Wait for less than the TTL - key should still exist
time.sleep(1)
assert name_resolve.get("ttl_key") == "ttl_value"
# Mock the keep-alive mechanism to simulate failure
with patch.object(
name_resolve._client, "refresh_lease", side_effect=Exception("Refresh failed")
):
# Wait longer than the TTL
time.sleep(3)
# Key should be gone if TTL is working
with pytest.raises(NameEntryNotFoundError):
name_resolve.get("ttl_key")
def test_subentry_with_custom_uuid(name_resolve, monkeypatch):
"""Test add_subentry with a predictable UUID for deterministic testing."""
# Mock uuid.uuid4 to return a predictable value
mock_uuid = MagicMock()
mock_uuid.return_value = "12345678-1234-5678-1234-567812345678"
monkeypatch.setattr(uuid, "uuid4", mock_uuid)
# Add a subentry
sub_name = name_resolve.add_subentry("test_root", "sub_value")
# Verify the subentry has the expected name
assert sub_name == "test_root/12345678"
assert name_resolve.get(sub_name) == "sub_value"
def test_race_condition_in_add(name_resolve):
"""Test race condition when adding the same key concurrently."""
if isinstance(name_resolve, NfsNameRecordRepository):
pytest.skip("NFS repo cannot tackle race conditions")
# Define the number of concurrent threads
num_threads = 10
key = "race_condition_key"
success_count = 0
failure_count = 0
def add_with_same_key():
nonlocal success_count, failure_count
try:
name_resolve.add(key, f"value_{threading.get_ident()}", replace=False)
success_count += 1
except NameEntryExistsError:
failure_count += 1
# Run concurrent add operations
threads = [threading.Thread(target=add_with_same_key) for _ in range(num_threads)]
for t in threads:
t.start()
for t in threads:
t.join()
# Verify only one thread succeeded
assert success_count == 1
assert failure_count == num_threads - 1
# Verify the key exists
assert name_resolve.get(key) is not None
def test_find_subtree_with_empty_result(name_resolve):
"""Test find_subtree behavior when no matching keys are found."""
# Ensure no keys exist with this prefix
prefix = "nonexistent_prefix"
# Call find_subtree
result = name_resolve.find_subtree(prefix)
# Verify result is an empty list, not None
assert result == []
assert isinstance(result, list)
def test_get_subtree_with_empty_result(name_resolve):
"""Test get_subtree behavior when no matching keys are found."""
# Ensure no keys exist with this prefix
prefix = "nonexistent_prefix"
# Call get_subtree
result = name_resolve.get_subtree(prefix)
# Verify result is an empty list, not None
assert result == []
assert isinstance(result, list)
def test_clear_subtree_with_nonexistent_prefix(name_resolve):
"""Test clear_subtree behavior with a nonexistent prefix."""
# Ensure no keys exist with this prefix
prefix = "nonexistent_prefix"
# Call clear_subtree - should not raise exception
name_resolve.clear_subtree(prefix)
# Add a key elsewhere and verify it's not affected
name_resolve.add("test_key", "test_value")
assert name_resolve.get("test_key") == "test_value"
def test_nested_subtrees(name_resolve):
"""Test behavior with deeply nested subtrees."""
# Create a deeply nested subtree
name_resolve.add("root/level1/level2/level3/key1", "value1")
name_resolve.add("root/level1/level2/key2", "value2")
name_resolve.add("root/level1/key3", "value3")
# Test get_subtree at different levels
assert set(name_resolve.get_subtree("root")) == {"value1", "value2", "value3"}
assert set(name_resolve.get_subtree("root/level1/level2")) == {"value1", "value2"}
# Test find_subtree at different levels
assert set(name_resolve.find_subtree("root/level1")) == {
"root/level1/level2/level3/key1",
"root/level1/level2/key2",
"root/level1/key3",
}
# Clear a subtree
name_resolve.clear_subtree("root/level1/level2")
# Verify only the specified subtree was cleared
with pytest.raises(NameEntryNotFoundError):
name_resolve.get("root/level1/level2/level3/key1")
with pytest.raises(NameEntryNotFoundError):
name_resolve.get("root/level1/level2/key2")
assert name_resolve.get("root/level1/key3") == "value3"
def test_corner_case_get_same_as_prefix(name_resolve):
"""Test get behavior when a key is both a prefix and a value."""
# Add entries
name_resolve.add("prefix", "parent_value")
name_resolve.add("prefix/child", "child_value")
# Verify both keys can be retrieved individually
assert name_resolve.get("prefix") == "parent_value"
assert name_resolve.get("prefix/child") == "child_value"
# Verify get_subtree includes both values
values = name_resolve.get_subtree("prefix")
assert set(values) == {"parent_value", "child_value"}
# Verify find_subtree includes both keys
keys = name_resolve.find_subtree("prefix")
assert set(keys) == {"prefix", "prefix/child"}
@pytest.mark.skipif(os.getenv("REAL_ETCD_ADDR") is None, reason="ETCD3 not configured")
def test_etcd3_specific_features(name_resolve):
"""Test keepalive/TTL features specific to the etcd3 backend."""
if not isinstance(name_resolve, Etcd3NameRecordRepository):
pytest.skip("ETCD3 specific test")
# Test the keepalive thread
name_resolve.add("test_key", "test_value", keepalive_ttl=2)
time.sleep(1) # Wait for the keepalive thread to refresh the lease
# Ensure the key still exists
value, _ = name_resolve._client.get("test_key")
assert value.decode("utf-8") == "test_value"
time.sleep(2) # Wait for the lease to expire
with pytest.raises(NameEntryNotFoundError):
name_resolve.get("test_key")
@pytest.mark.skipif(os.getenv("REAL_ETCD_ADDR") is not None, reason="NFS specific test")
def test_nfs_specific_features(name_resolve):
"""Test features specific to NFS backend."""
from realhf.base.name_resolve import NfsNameRecordRepository
if not isinstance(name_resolve, NfsNameRecordRepository):
pytest.skip("NFS specific test")
# Test handling of stale file handles
name_resolve.add("test_key", "test_value")
original_open = open
call_count = 0
def mock_open(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count <= 3: # Fail first 3 times
raise OSError(116, "Stale file handle")
return original_open(*args, **kwargs)
with patch("builtins.open", mock_open):
assert name_resolve.get("test_key") == "test_value"
assert call_count == 4
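The stale-file-handle recovery tested above amounts to retrying reads that fail with errno 116 (ESTALE, as NFS may hand out stale handles); a minimal standalone sketch of that retry pattern, with an illustrative retry count and backoff rather than the repository's actual values:
import errno
import time

def read_with_estale_retry(path, retries=5, delay=0.05):
    # Retry reads that fail with ESTALE (errno 116 on Linux).
    for attempt in range(retries):
        try:
            with open(path) as f:
                return f.read()
        except OSError as e:
            if e.errno != errno.ESTALE or attempt == retries - 1:
                raise
            time.sleep(delay)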

View File

@ -1,272 +0,0 @@
import os
import shutil
import tempfile
import time
import uuid
from unittest.mock import patch
import pytest
from realhf.base.name_resolve import (
NameEntryExistsError,
NameEntryNotFoundError,
NfsNameRecordRepository,
)
@pytest.fixture
def temp_nfs_root():
# Create a temporary directory to simulate NFS root
temp_dir = tempfile.mkdtemp()
original_root = NfsNameRecordRepository.RECORD_ROOT
NfsNameRecordRepository.RECORD_ROOT = temp_dir
yield temp_dir
# Cleanup
NfsNameRecordRepository.RECORD_ROOT = original_root
shutil.rmtree(temp_dir)
@pytest.fixture
def nfs_repo(temp_nfs_root):
repo = NfsNameRecordRepository()
yield repo
repo.reset()
def test_add_basic(nfs_repo):
# Test basic add functionality
nfs_repo.add("test_key", "test_value")
assert nfs_repo.get("test_key") == "test_value"
# Verify file was created
assert os.path.isfile(
os.path.join(NfsNameRecordRepository.RECORD_ROOT, "test_key/ENTRY")
)
# Non-string value
nfs_repo.add(
"test_key", 123, replace=True
) # Should fail if non-string values aren't converted
assert nfs_repo.get("test_key") == str(123)
with pytest.raises(ValueError):
nfs_repo.add("", "value")
def test_add_with_replace(nfs_repo):
# Test add with replace=False (should raise)
nfs_repo.add("test_key", "test_value")
with pytest.raises(NameEntryExistsError):
nfs_repo.add("test_key", "new_value", replace=False)
# Test add with replace=True
nfs_repo.add("test_key", "new_value", replace=True)
assert nfs_repo.get("test_key") == "new_value"
def test_add_delete_on_exit(nfs_repo):
# Test delete_on_exit flag
nfs_repo.add("test_key1", "value1", delete_on_exit=True)
nfs_repo.add("test_key2", "value2", delete_on_exit=False)
assert "test_key1" in nfs_repo._NfsNameRecordRepository__to_delete
assert "test_key2" not in nfs_repo._NfsNameRecordRepository__to_delete
def test_delete(nfs_repo):
# Test deleting existing key
nfs_repo.add("test_key", "test_value")
nfs_repo.delete("test_key")
with pytest.raises(NameEntryNotFoundError):
nfs_repo.get("test_key")
# Test deleting non-existent key
with pytest.raises(NameEntryNotFoundError):
nfs_repo.delete("non_existent_key")
def test_delete_cleanup_dirs(nfs_repo):
# Test that empty parent directories are cleaned up
nfs_repo.add("test/path/key", "value")
assert os.path.isdir(os.path.join(NfsNameRecordRepository.RECORD_ROOT, "test/path"))
nfs_repo.delete("test/path/key")
# Should clean up empty parent directories
assert not os.path.exists(
os.path.join(NfsNameRecordRepository.RECORD_ROOT, "test/path")
)
assert not os.path.exists(os.path.join(NfsNameRecordRepository.RECORD_ROOT, "test"))
def test_clear_subtree(nfs_repo):
# Test clearing a subtree
nfs_repo.add("test_root/key1", "value1")
nfs_repo.add("test_root/key2", "value2")
nfs_repo.add("test_root/sub/key3", "value3")
nfs_repo.add("other_root/key", "value")
nfs_repo.clear_subtree("test_root")
# Verify subtree is gone
assert nfs_repo.get_subtree("test_root") == []
assert nfs_repo.find_subtree("test_root") == []
# Verify other tree is intact
assert nfs_repo.get("other_root/key") == "value"
def test_get(nfs_repo):
# Test getting existing key
nfs_repo.add("test_key", "test_value")
assert nfs_repo.get("test_key") == "test_value"
# Test getting non-existent key
with pytest.raises(NameEntryNotFoundError):
nfs_repo.get("non_existent_key")
def test_get_stale_file_handle_recovery(nfs_repo):
# Test handling of stale file handles
nfs_repo.add("test_key", "test_value")
# Mock os.open to raise OSError with errno 116 (ESTALE) first few times
original_open = open
def mock_open(*args, **kwargs):
mock_open.call_count += 1
if mock_open.call_count <= 3: # Fail first 3 times
raise OSError(116, "Stale file handle")
return original_open(*args, **kwargs)
mock_open.call_count = 0
with patch("builtins.open", mock_open):
assert nfs_repo.get("test_key") == "test_value"
assert mock_open.call_count == 4
def test_get_subtree(nfs_repo):
# Test getting subtree values
nfs_repo.add("test_root/key1", "value1")
nfs_repo.add("test_root/key2", "value2")
nfs_repo.add("test_root/sub/key3", "value3")
values = nfs_repo.get_subtree("test_root")
assert set(values) == {"value1", "value2"}
def test_find_subtree(nfs_repo):
# Test finding subtree keys
nfs_repo.add("test_root/key1", "value1")
nfs_repo.add("test_root/key2", "value2")
nfs_repo.add("test_root/sub/key3", "value3")
keys = nfs_repo.find_subtree("test_root")
assert set(keys) == {"test_root/key1", "test_root/key2"}
assert keys == sorted(keys) # Should be sorted
def test_reset(nfs_repo):
# Test reset functionality
nfs_repo.add("test_key1", "value1", delete_on_exit=True)
nfs_repo.add("test_key2", "value2", delete_on_exit=False)
nfs_repo.reset()
# Only test_key1 should be deleted
with pytest.raises(NameEntryNotFoundError):
nfs_repo.get("test_key1")
assert nfs_repo.get("test_key2") == "value2"
def test_context_manager(nfs_repo):
# Test context manager functionality
with NfsNameRecordRepository() as repo:
repo.add("test_key", "test_value", delete_on_exit=True)
assert repo.get("test_key") == "test_value"
# Key should be deleted after context exits
with pytest.raises(NameEntryNotFoundError):
nfs_repo.get("test_key")
def test_destructor(nfs_repo):
# Test destructor functionality
repo = NfsNameRecordRepository()
repo.add("test_key", "test_value", delete_on_exit=True)
# Simulate object destruction
repo.__del__()
# Key should be deleted
with pytest.raises(NameEntryNotFoundError):
nfs_repo.get("test_key")
def test_add_subentry(nfs_repo):
# Test subentry creation
sub_name = nfs_repo.add_subentry("test_root", "sub_value")
assert sub_name.startswith("test_root/")
assert nfs_repo.get(sub_name) == "sub_value"
def test_wait(nfs_repo):
# Test wait functionality
import threading
def delayed_add():
time.sleep(0.1)
nfs_repo.add("test_key", "test_value")
job = threading.Thread(target=delayed_add, daemon=True)
job.start()
# Should return once key is added
assert nfs_repo.wait("test_key", timeout=2) == "test_value"
job.join()
# Test timeout
with pytest.raises(TimeoutError):
nfs_repo.wait("non_existent_key", timeout=0.1)
def test_watch_names(nfs_repo):
# Test watch functionality
callback_called = False
def callback():
nonlocal callback_called
callback_called = True
nfs_repo.add("test_key", "test_value")
nfs_repo.watch_names("test_key", callback)
# Delete the key to trigger callback
nfs_repo.delete("test_key")
# Wait for callback
time.sleep(5) # Give watcher thread time to execute
assert callback_called
def test_concurrent_access(nfs_repo):
# Test concurrent access to the same key
import threading
nfs_repo.add("test_key", "initial_value")
def modify_value():
for i in range(10):
current = nfs_repo.get("test_key")
nfs_repo.add("test_key", f"modified_{i}", replace=True)
time.sleep(0.01)
threads = [threading.Thread(target=modify_value) for _ in range(5)]
for t in threads:
t.start()
for t in threads:
t.join()
# Final value should be one of the modified values
final_value = nfs_repo.get("test_key")
assert final_value.startswith("modified_")

View File

@ -7,13 +7,8 @@ import pytest
from realhf.api.cli_args import (
ExperimentSaveEvalControl,
GenerationHyperparameters,
MFCConfig,
MicroBatchSpec,
ModelTrainEvalConfig,
ParallelismConfig,
PPOHyperparameters,
PromptOnlyDatasetConfig,
PromptAnswerDatasetConfig,
)
from realhf.base import cluster, testing
from realhf.experiments.common.sft_exp import SFTConfig
@ -74,6 +69,7 @@ def test_sft(tmp_path_factory, tokenizer, save_path, cpu_hf_model, dp, pp, tp):
minbs = 32
exp_cfg = SFTConfig(
exp_ctrl=ExperimentSaveEvalControl(eval_freq_steps=2),
experiment_name=testing._DEFAULT_EXPR_NAME,
trial_name=testing._DEFAULT_TRIAL_NAME,
mode="local",
@ -86,8 +82,8 @@ def test_sft(tmp_path_factory, tokenizer, save_path, cpu_hf_model, dp, pp, tp):
backend="mock_train",
),
dataset=PromptAnswerDatasetConfig(
train_path=str(save_path / "dataset.json"),
valid_path=str(save_path / "dataset.json"),
train_path=str(save_path / "dataset.jsonl"),
valid_path=str(save_path / "dataset.jsonl"),
max_seqlen=128,
train_bs_n_seqs=minbs,
valid_bs_n_seqs=minbs,

View File

@ -4,6 +4,7 @@
import json
import random
import uuid
import pytest
import torch
@ -50,9 +51,12 @@ def dataset(request, save_path):
for i in range(size):
prompt_len = random.randint(1, max_prompt_len)
n_pairs = random.randint(1, 5)
qid = str(uuid.uuid4())
d = dict(
id=i,
id=qid,
query_id=qid,
prompt=generate_random_sentence(prompt_len),
solutions=["\\boxed{xxxxx}"],
answer=generate_random_sentence(random.randint(1, max_resp_len)),
pos_answers=[
generate_random_sentence(random.randint(1, max_resp_len))
@ -62,10 +66,12 @@ def dataset(request, save_path):
generate_random_sentence(random.randint(1, max_resp_len))
for _ in range(n_pairs)
],
task="math",
)
dataset.append(d)
with open(str(save_path / "dataset.json"), "w") as f:
json.dump(dataset, f)
with open(str(save_path / "dataset.jsonl"), "w") as f:
for d in dataset:
f.write(json.dumps(d) + "\n")
return dataset
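Since the fixture now writes one JSON object per line, a consumer reads the file back line by line; a small illustrative reader, not part of the change:
import json

def load_jsonl(path):
    # One JSON object per line; skip blank lines.
    with open(path) as f:
        return [json.loads(line) for line in f if line.strip()]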

View File

@ -0,0 +1,100 @@
import asyncio
import random
import uuid
import pytest
import torch
from realhf.api.cli_args import GenerationHyperparameters
from realhf.api.core.model_api import (
APIGenerateInput,
APIGenerateOutput,
BundledGenerationOutputs,
)
@pytest.fixture
def sglang_client(request):
from sglang.test.test_utils import is_in_ci
if is_in_ci():
from patch import launch_server_cmd
else:
from sglang.utils import launch_server_cmd
from sglang.utils import terminate_process, wait_for_server
server_process, port = launch_server_cmd(
f"python -m sglang.launch_server --model-path {request.param} --host 0.0.0.0 --skip-tokenizer-init "
)
wait_for_server(f"http://localhost:{port}")
from realhf.impl.model.backend.sglang import SGLangAPIClient
client = SGLangAPIClient(
generate_url=f"http://localhost:{port}/generate",
update_weights_url=f"http://localhost:{port}/update_weights_from_disk",
)
yield client
terminate_process(server_process)
@pytest.mark.parametrize(
"sglang_client",
["/storage/openpsi/models/Qwen__Qwen2.5-7B-Instruct/"],
indirect=True,
)
@pytest.mark.parametrize("group_size", [16])
@pytest.mark.asyncio
async def test_batch_generate(sglang_client, group_size):
bs = 8
genlen = 10  # kept short for the test; real runs may use much larger values (e.g., 16384)
prompt_len = 100
async with sglang_client:
tasks = []
qids = []
for i in range(bs):
qid = str(uuid.uuid4())
prompt_ids = [random.randint(10, 100) for _ in range(prompt_len)]
gconfig = GenerationHyperparameters(
n=group_size,
max_new_tokens=genlen,
)
req = APIGenerateInput(
qid=qid,
prompt_ids=prompt_ids,
input_ids=prompt_ids,
gconfig=gconfig,
return_logprob=True,
)
tasks.append(sglang_client.async_add_generate_request(req, stream=False))
qids.append(qid)
outputs = {}
for r in asyncio.as_completed(tasks):
out = await r
outputs[out.qid] = out
results = [outputs[key] for key in qids]
assert all([isinstance(r, APIGenerateOutput) for r in results])
batch_token_ids = []
batch_logprobs = []
max_seqlen = -1
for x in results:
max_seqlen = max(max_seqlen, max(x.output_lens))
batch_token_ids += x.output_ids
batch_logprobs += x.output_logprobs
pad_token_id = 0
# Pad generated tokens and logprobs to a common length,
# to stay consistent with our internal implementation
batch_token_ids = [
t + [pad_token_id] * (max_seqlen - len(t)) for t in batch_token_ids
]
batch_logprobs = [p + [0.0] * (max_seqlen - len(p)) for p in batch_logprobs]
tokens = torch.tensor(batch_token_ids, dtype=torch.long, device="cpu")
assert tokens.shape == (bs * group_size, genlen)
logprobs = torch.tensor(batch_logprobs, dtype=torch.float32, device="cpu")
assert logprobs.shape == (bs * group_size, genlen)
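The manual padding above can also be expressed with torch's pad_sequence; an equivalent illustrative sketch (not how the test does it), assuming ragged lists of token ids:
import torch
from torch.nn.utils.rnn import pad_sequence

token_lists = [[5, 6, 7], [8, 9], [10]]  # illustrative ragged outputs
tokens = pad_sequence(
    [torch.tensor(t, dtype=torch.long) for t in token_lists],
    batch_first=True,
    padding_value=0,  # pad_token_id
)
# tokens.shape == (3, 3); shorter rows are right-padded with the pad id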

View File

@ -0,0 +1,255 @@
# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0 (the "License").
import dataclasses
import random
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import Optional
import pytest
from realhf.api.core.config import ModelName
from realhf.api.core.system_api import GserverManager as GserverManagerConfig
from realhf.api.core.system_api import WorkerInformation
from realhf.base import constants, name_resolve, names, network, testing
from realhf.base.names import gen_servers
from realhf.system.gserver_manager import GserverManager
N_SERVERS = 2
UPDATE_WEIGHTS_CALL_COUNT = defaultdict(int)
@dataclass
class UpdateWeightFromDiskReqInput:
# The model path with the new weights
model_path: str
# The format to load the weights
load_format: Optional[str] = None
@pytest.fixture
def mock_servers():
from http import HTTPStatus
from threading import Thread
import uvicorn
from fastapi import FastAPI
from fastapi.responses import ORJSONResponse, PlainTextResponse
ports = network.find_multiple_free_ports(N_SERVERS)
# Create mock server responses
servers = []
jobs = []
should_run = True
def run_app1():
app = FastAPI()
@app.get("/metrics")
async def metrics1():
n_requests = random.choice(list(range(4)))
return PlainTextResponse(
f"sglang:num_running_reqs {float(n_requests)}\nother_metric 123"
)
@app.post("/update_weights_from_disk/")
async def update_weights_from_disk1(req: UpdateWeightFromDiskReqInput):
UPDATE_WEIGHTS_CALL_COUNT[ports[0]] += 1
return ORJSONResponse(
{"success": True, "message": "Weights updated successfully"},
status_code=HTTPStatus.OK,
)
config = uvicorn.Config(
app,
host="localhost",
port=ports[0],
log_level="info",
)
server = uvicorn.Server(config)
servers.append(server)
while should_run:
server.run()
def run_app2():
app = FastAPI()
@app.get("/metrics")
async def metrics2():
n_requests = random.choice(list(range(4)))
return PlainTextResponse(
f"sglang:num_running_reqs {float(n_requests)}\nother_metric 123"
)
@app.post("/update_weights_from_disk/")
async def update_weights_from_disk2(req: UpdateWeightFromDiskReqInput):
UPDATE_WEIGHTS_CALL_COUNT[ports[1]] += 1
return ORJSONResponse(
{"success": True, "message": "Weights updated successfully"},
status_code=HTTPStatus.OK,
)
config = uvicorn.Config(
app,
host="localhost",
port=ports[1],
log_level="info",
)
server = uvicorn.Server(config)
servers.append(server)
while should_run:
server.run()
job1 = Thread(target=run_app1)
job2 = Thread(target=run_app2)
jobs = [job1, job2]
for job in jobs:
job.start()
yield ports
should_run = False
for server in servers:
server.should_exit = True
for job in jobs:
job.join()
@pytest.fixture
def gserver_manager(mock_servers):
testing.clear_name_resolve()
constants.set_experiment_trial_names(
testing._DEFAULT_EXPR_NAME, testing._DEFAULT_TRIAL_NAME
)
# Register mock servers in name resolve
server_urls = [f"http://localhost:{mock_servers[i]}/" for i in range(N_SERVERS)]
name = gen_servers(testing._DEFAULT_EXPR_NAME, testing._DEFAULT_TRIAL_NAME)
name_resolve.add_subentry(name, server_urls[0])
name_resolve.add_subentry(name, server_urls[1])
# Construct and configure the GserverManager under test
m = GserverManager()
config = GserverManagerConfig(
model_name=ModelName("default", 0),
n_servers=N_SERVERS,
schedule_policy="round_robin",
worker_info=WorkerInformation(
experiment_name=testing._DEFAULT_EXPR_NAME,
trial_name=testing._DEFAULT_TRIAL_NAME,
worker_type="gserver_manager",
worker_count=1,
worker_index=0,
),
)
m._configure(config)
# launch the server
m._poll()
yield m
m.exit()
@pytest.mark.asyncio
async def test_schedule_policy(gserver_manager):
# Test round-robin scheduling
from realhf.api.core.model_api import GenReqMeta
req_meta = GenReqMeta(
prompt_len=100,
group_size=2,
new_token_budget=1024,
predicted_new_tokens=None,
)
idx1 = gserver_manager._round_robin_schedule(req_meta)
assert idx1 == 0
idx2 = gserver_manager._round_robin_schedule(req_meta)
assert idx2 == 1
idx3 = gserver_manager._round_robin_schedule(req_meta)
assert idx3 == 0
@pytest.mark.asyncio
async def test_weight_update(gserver_manager):
from fastapi.testclient import TestClient
from realhf.api.core.model_api import GenReqMeta
client = TestClient(gserver_manager.app)
# Set up a new parameter version
name = names.model_version(
testing._DEFAULT_EXPR_NAME, testing._DEFAULT_TRIAL_NAME, "default"
)
name_resolve.add(name, "1")
global UPDATE_WEIGHTS_CALL_COUNT
UPDATE_WEIGHTS_CALL_COUNT.clear()
req_meta = GenReqMeta(
prompt_len=100,
group_size=2,
new_token_budget=1024,
predicted_new_tokens=None,
)
client.post("/schedule_request", json=dataclasses.asdict(req_meta))
assert gserver_manager._last_param_realloc_step == 1
assert len(UPDATE_WEIGHTS_CALL_COUNT) == N_SERVERS
for v in UPDATE_WEIGHTS_CALL_COUNT.values():
assert v == 1
# weights updated, no more weights update
client.post("/schedule_request", json=dataclasses.asdict(req_meta))
assert gserver_manager._last_param_realloc_step == 1
assert len(UPDATE_WEIGHTS_CALL_COUNT) == N_SERVERS
for v in UPDATE_WEIGHTS_CALL_COUNT.values():
assert v == 1
UPDATE_WEIGHTS_CALL_COUNT.clear()
# Clean up the version entry so subsequent tests start with no version set
name_resolve.delete(name)
def test_server_lifecycle(gserver_manager):
# Test that the server starts and stops properly
assert gserver_manager.thread is not None
assert gserver_manager.thread.is_alive()
# Test exit hook
gserver_manager.exit()
time.sleep(0.1) # Give thread time to stop
assert not gserver_manager.thread.is_alive()
@pytest.mark.asyncio
async def test_http_server_endpoints(gserver_manager):
# Test the FastAPI endpoints
from fastapi.testclient import TestClient
from realhf.api.core.model_api import GenReqMeta
client = TestClient(gserver_manager.app)
# Test schedule_request endpoint
req_meta = GenReqMeta(
prompt_len=100,
group_size=2,
new_token_budget=1024,
predicted_new_tokens=None,
)
# Test round-robin behavior through the endpoint
responses = set()
for _ in range(10):
response = client.post("/schedule_request", json=dataclasses.asdict(req_meta))
assert response.status_code == 200
responses.add(response.json()["url"])
# Should have used all available servers
assert responses == set(gserver_manager.server_urls)
def test_unique_server_urls(gserver_manager):
# Ensure server URLs are unique
assert len(set(gserver_manager.server_urls)) == len(gserver_manager.server_urls)
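The round-robin behaviour asserted above (indices 0, 1, 0 over N_SERVERS=2) can be pictured with a toy scheduler; this is a sketch of the expected policy, not the GserverManager implementation:
class ToyRoundRobin:
    def __init__(self, n_servers: int):
        self.n_servers = n_servers
        self._next = 0

    def schedule(self) -> int:
        # Hand out server indices cyclically.
        idx = self._next
        self._next = (self._next + 1) % self.n_servers
        return idx

rr = ToyRoundRobin(2)
assert [rr.schedule() for _ in range(3)] == [0, 1, 0]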

View File

@ -0,0 +1,140 @@
import asyncio
import random
import string
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from transformers import PreTrainedTokenizerFast
from realhf.api.core.model_api import (
APIGenerateInput,
APIGenerateOutput,
GenerationHyperparameters,
LLMAPIClient,
)
from realhf.base import constants, name_resolve, names, testing
from realhf.base.names import gen_server_manager
from realhf.system.partial_rollout import PartialRolloutManager
class MockLLMAPIClient(LLMAPIClient):
def setup(self):
pass
async def close(self):
await asyncio.sleep(0.5)
async def async_add_generate_request(
self, req: APIGenerateInput, stream: bool = True
) -> APIGenerateOutput:
await asyncio.sleep(1)
output_len = req.gconfig.max_new_tokens
out = APIGenerateOutput.from_input(req)
out.output_ids = [
[
random.randint(0, testing.TESTING_MODEL_VOCAB_SIZE - 1)
for _ in range(output_len)
]
for _ in range(req.gconfig.n)
]
if req.return_logprob:
out.output_logprobs = [
[random.random() for _ in range(output_len)]
for _ in range(req.gconfig.n)
]
out.no_eos = [True for _ in range(req.gconfig.n)]
return out
async def async_update_weights_from_disk(self, path):
await asyncio.sleep(10)
async def get_cur_version(self):
await asyncio.sleep(0)
return 1
new_tokens_per_chunk = 16
@pytest.fixture
def partial_rollout_manager():
# Set up mocked tokenizer
mock_tokenizer = MagicMock(spec=PreTrainedTokenizerFast)
mock_tokenizer.pad_token_id = 0
mock_tokenizer.eos_token_id = 1
# Set up mocked request and reply queues
request_queue = asyncio.Queue()
reply_queue = asyncio.Queue()
testing.clear_name_resolve()
constants.set_experiment_trial_names(
testing._DEFAULT_EXPR_NAME, testing._DEFAULT_TRIAL_NAME
)
name = gen_server_manager(testing._DEFAULT_EXPR_NAME, testing._DEFAULT_TRIAL_NAME)
name_resolve.add(name, "http://fake.com")
global new_tokens_per_chunk
# Initialize PartialRolloutManager
manager = PartialRolloutManager(
worker_index=0,
request_queue=request_queue,
reply_queue=reply_queue,
new_tokens_per_chunk=new_tokens_per_chunk,
tokenizer=mock_tokenizer,
)
yield manager
# Cleanup if needed
@pytest.mark.asyncio
async def test_run_step_calls_issue_generation(partial_rollout_manager):
"""
Test `run_step` calls `_issue_generation` the correct number of times.
"""
with (
patch(
"realhf.system.partial_rollout.PartialRolloutManager._schedule_request",
new_callable=AsyncMock,
),
patch(
"realhf.impl.model.backend.sglang.SGLangAPIClient",
new_callable=lambda: MockLLMAPIClient,
),
):
# Define inputs
gen_random_qid = lambda length=8: "".join(
random.choice(string.ascii_letters + string.digits) for _ in range(length)
)
group_size = 10
test_prompt_ids = list(range(100))
test_max_new_tokens = 128
assert test_max_new_tokens % new_tokens_per_chunk == 0
test_raw_gconfig = GenerationHyperparameters(
n=group_size, max_new_tokens=test_max_new_tokens
)
q_len = 8  # number of distinct queries
test_qids = [f"{gen_random_qid(10)}-{i}" for i in range(q_len)]
for test_qid in test_qids:
partial_rollout_manager.request_queue.put_nowait(
(test_qid, test_prompt_ids.copy(), test_raw_gconfig.new())
)
with patch.object(
partial_rollout_manager,
"_issue_generation",
side_effect=partial_rollout_manager._issue_generation,
) as mock_issue_generation:
while partial_rollout_manager.reply_queue.qsize() != q_len:
await partial_rollout_manager.run_step()
assert (
mock_issue_generation.call_count
== q_len
* group_size
* test_max_new_tokens
// partial_rollout_manager.new_tokens_per_chunk
)
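The call-count assertion above follows from issuing generation in fixed-size chunks; with the fixture values the arithmetic works out as below (a worked example, not extra test logic):
q_len, group_size = 8, 10
max_new_tokens, new_tokens_per_chunk = 128, 16
chunks_per_sequence = max_new_tokens // new_tokens_per_chunk  # 8 partial-generation calls per sequence
expected_calls = q_len * group_size * chunks_per_sequence      # 8 * 10 * 8 = 640
assert expected_calls == 640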

View File

@ -0,0 +1,174 @@
import asyncio
import copy
from asyncio.queues import QueueEmpty
from unittest.mock import patch
import pytest
from realhf.api.core.config import (
AgentAbstraction,
DatasetAbstraction,
EnvServiceAbstraction,
ModelName,
)
from realhf.api.core.model_api import (
BundledGenerationOutputs,
GenerationHyperparameters,
)
from realhf.api.core.system_api import RolloutWorker as RolloutWorkerConfig
from realhf.api.core.system_api import WorkerInformation
from realhf.base import constants, name_resolve, names, network, testing
from realhf.system.push_pull_stream import NameResolvingZmqPusher
from realhf.system.rollout_worker import RolloutWorker
from tests.fixtures import *
N_PULLERS = 3
class MockPartialRolloutManager:
def __init__(self, request_queue, reply_queue, **kwargs):
self.request_queue = request_queue
self.reply_queue = reply_queue
self.internal_queue = []
def get_num_gen_requests(self):
return len(self.internal_queue)
async def run_step(self):
async def poll_fresh_requests():
for _ in range(8):
try:
qid, prompt_token_ids, gconfig = self.request_queue.get_nowait()
assert isinstance(qid, str)
assert isinstance(prompt_token_ids, list)
assert all(isinstance(x, int) for x in prompt_token_ids)
assert isinstance(gconfig, GenerationHyperparameters)
self.internal_queue.append(qid)
except QueueEmpty:
await asyncio.sleep(0.01)
async def poll_old_requests():
for _ in range(8):
if random.random() < 0.5 and len(self.internal_queue) > 0:
# responses may not return in order
idx = random.randint(0, len(self.internal_queue) - 1)
qid = self.internal_queue.pop(idx)
out = BundledGenerationOutputs(
qid=qid,
prompt_ids=[1],
output_ids=[[2], [3]],
seqs=[[1, 2], [1, 3]],
logprobs=[[0.0, 0.1], [0.0, 2.0]],
no_eos=[True, True],
version_start=[0, 1],
version_end=[1, 2],
)
await self.reply_queue.put(out)
else:
await asyncio.sleep(0.01)
await asyncio.gather(poll_fresh_requests(), poll_old_requests())
@pytest.fixture
def rollout_workers(request):
testing.clear_name_resolve()
constants.set_experiment_trial_names(
testing._DEFAULT_EXPR_NAME, testing._DEFAULT_TRIAL_NAME
)
# add name resolve to make zmq pusher happy
puller_ports = network.find_multiple_free_ports(N_PULLERS)
for puller_index in range(N_PULLERS):
name = names.stream_pullers(
testing._DEFAULT_EXPR_NAME,
testing._DEFAULT_TRIAL_NAME,
)
name_resolve.add_subentry(name, str(puller_index))
name = names.push_pull_stream(
testing._DEFAULT_EXPR_NAME,
testing._DEFAULT_TRIAL_NAME,
f"puller{puller_index}",
)
name_resolve.add(name, f"localhost:{puller_ports[puller_index]}")
with (
patch.object(NameResolvingZmqPusher, "push", return_value=None) as mock_push,
patch(
"realhf.system.rollout_worker.PartialRolloutManager",
new=MockPartialRolloutManager,
),
):
ms = [RolloutWorker() for _ in range(request.param)]
yield ms
@pytest.mark.parametrize("rollout_workers", [1, 2, 3], indirect=True)
@pytest.mark.parametrize("offpolicyness", [0, 1, 4])
@pytest.mark.asyncio
async def test_offpolicyness_control(
rollout_workers, save_path, dataset, offpolicyness
):
train_batch_size = 8
config = RolloutWorkerConfig(
base_seed=0,
model_name=ModelName("default", 0),
max_head_offpolicyness=offpolicyness,
train_batch_size=train_batch_size,
tokenizer_path="/storage/openpsi/models/Qwen__Qwen2.5-1.5B/",
new_tokens_per_chunk=1024,
max_concurrent_rollouts=128,
env=EnvServiceAbstraction("null"),
agent=AgentAbstraction("null", dict(episode_length=5, traj_size=5)),
datasets=[
DatasetAbstraction(
"prompt", args=dict(dataset_path=str(save_path / "dataset.jsonl"))
)
],
worker_info=WorkerInformation(
experiment_name=testing._DEFAULT_EXPR_NAME,
trial_name=testing._DEFAULT_TRIAL_NAME,
worker_type="rollout_worker",
worker_count=N_PULLERS * 2,
worker_index=0,
),
)
for i, m in enumerate(rollout_workers):
config = copy.deepcopy(config)
config.worker_info.worker_index = i
m._configure(config)
for i in range(10 * (offpolicyness + 1)):
for m in rollout_workers:
await m._poll_async()
# Ensure that data is not over-produced beyond the offpolicyness budget
for m in rollout_workers:
assert m.agent.ACT_GET_CNT > 0
assert (
(offpolicyness + 1) * train_batch_size >= m.push_stream.push.call_count > 0
)
# Increase the model version by 1
version_name = names.model_version(
constants.experiment_name(),
constants.trial_name(),
config.model_name.role,
)
name_resolve.add(version_name, "1")
# Run the rollout worker again
for i in range(10 * (offpolicyness + 1)):
for m in rollout_workers:
await m._poll_async()
# The rollout worker should produce new samples
for m in rollout_workers:
assert (offpolicyness + 2) * train_batch_size >= m.push_stream.push.call_count
assert (offpolicyness + 1) * train_batch_size < m.push_stream.push.call_count
# Final clean up
name_resolve.delete(version_name)
for m in rollout_workers:
await m._exit_async_tasks()
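The push-count bounds asserted above reduce to a simple capacity rule: rollouts may run at most max_head_offpolicyness versions ahead, so at model version v the worker may have pushed at most (v + 1 + offpolicyness) * train_batch_size samples. A sketch of that rule as read from the assertions, not the worker's code:
def max_pushed_samples(current_version: int, offpolicyness: int, train_batch_size: int) -> int:
    # Upper bound on samples pushed so far when rollouts may be at most
    # `offpolicyness` versions ahead of the trained model.
    return (current_version + 1 + offpolicyness) * train_batch_size

assert max_pushed_samples(0, 1, 8) == 16  # bound before the version bump (offpolicyness=1, bs=8)
assert max_pushed_samples(1, 1, 8) == 24  # bound after bumping the model version to 1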

View File

@ -0,0 +1,158 @@
import queue
import random
import time
from unittest.mock import patch
import pytest
import torch
from torch.utils.data import DataLoader
from realhf.api.core import config as config_api
from realhf.api.core import data_api
from realhf.base import constants, testing
from tests.fixtures import *
@pytest.fixture
def prompt_dataset_cfg(dataset, tokenizer):
return [
config_api.DatasetAbstraction(
type_="prompt",
args=dict(
max_length=32,
dataset_builder=lambda: dataset,
),
)
]
class MockPuller:
def __init__(self, *args, **kwargs):
self.pull_count = 0
def pull(self, timeout_ms: float):
self.pull_count += 1
if self.pull_count % 2 == 1: # Return data every other call
return [
data_api.SequenceSample.as_json_compatible(
data_api.SequenceSample(
ids=[str(123)],
data={"packed_prompts": torch.tensor([1, 2, 3])},
keys=set(["packed_prompts"]),
seqlens=dict(packed_prompts=[[3]]),
dtypes=dict(packed_prompts=torch.long),
trailing_shapes=dict(packed_prompts=()),
)
)
]
raise queue.Empty()
@pytest.fixture
def mock_puller(monkeypatch):
monkeypatch.setattr(
"realhf.system.stream_dataset.NameResolvingZmqPuller", MockPuller
)
def test_load_stream_dataset(prompt_dataset_cfg, tokenizer, mock_puller):
import realhf.impl.dataset # isort: skip
from realhf.api.core.data_api import make_dataset
from realhf.system.stream_dataset import PullerStreamDataset
constants.set_experiment_trial_names(
testing._DEFAULT_EXPR_NAME, testing._DEFAULT_TRIAL_NAME
)
testing.clear_name_resolve()
util = data_api.DatasetUtility(
seed=42,
dp_rank=0,
world_size=1,
tokenizer=tokenizer,
)
# Test initialization
dataset = PullerStreamDataset(util, prompt_dataset_cfg, pull_timeout_ms=100)
assert len(dataset) > 0 # Should have non-zero size from prompt dataset
assert dataset.data_queue.empty()
# Test data pulling
time.sleep(0.2) # Give worker thread time to pull some data
items = dataset[0] # This should get all available items from queue
assert len(items) > 0
assert isinstance(items[0], data_api.SequenceSample)
# Test queue behavior with multiple gets
time.sleep(0.2)
items1 = dataset[0]
items2 = dataset[0]
assert len(items1) + len(items2) > 0
# Test cleanup
del dataset
def test_puller_stream_dataset_timeout(prompt_dataset_cfg, tokenizer):
from realhf.system.stream_dataset import PullerStreamDataset
testing.clear_name_resolve()
util = data_api.DatasetUtility(
seed=42,
dp_rank=0,
world_size=1,
tokenizer=tokenizer,
)
with patch("realhf.system.stream_dataset.NameResolvingZmqPuller") as mock_puller:
mock_puller.return_value.pull.side_effect = queue.Empty()
dataset = PullerStreamDataset(util, prompt_dataset_cfg, pull_timeout_ms=10)
# Should handle timeout gracefully
assert dataset[0] == []
del dataset
def test_puller_stream_dataset_stop_event(prompt_dataset_cfg, tokenizer, mock_puller):
from realhf.system.stream_dataset import PullerStreamDataset
testing.clear_name_resolve()
util = data_api.DatasetUtility(
seed=42,
dp_rank=0,
world_size=1,
tokenizer=tokenizer,
)
dataset = PullerStreamDataset(util, prompt_dataset_cfg)
assert not dataset._stop_event.is_set()
# Trigger stop event and verify thread stops
dataset._stop_event.set()
time.sleep(0.1)
assert not dataset.worker_thread.is_alive()
del dataset
def test_puller_stream_dataset_worker_thread_exception(prompt_dataset_cfg, tokenizer):
from realhf.system.stream_dataset import PullerStreamDataset
testing.clear_name_resolve()
util = data_api.DatasetUtility(
seed=42,
dp_rank=0,
world_size=1,
tokenizer=tokenizer,
)
with patch("realhf.system.stream_dataset.NameResolvingZmqPuller") as mock_puller:
mock_puller.return_value.pull.side_effect = Exception("Test error")
dataset = PullerStreamDataset(util, prompt_dataset_cfg)
time.sleep(0.1) # Give thread time to crash
assert not dataset.worker_thread.is_alive()
del dataset
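The PullerStreamDataset behaviour covered above (a daemon thread pulling into a queue, tolerating pull timeouts, and stopping on an event or on an unexpected error) follows a common producer-thread pattern; a minimal standalone sketch of that pattern, independent of realhf:
import queue
import threading

class BackgroundPuller:
    def __init__(self, pull_fn, timeout_note="pull_fn raises queue.Empty on timeout"):
        self._pull_fn = pull_fn                 # callable returning an item or raising queue.Empty
        self.data_queue = queue.Queue()
        self._stop_event = threading.Event()
        self.worker_thread = threading.Thread(target=self._run, daemon=True)
        self.worker_thread.start()

    def _run(self):
        while not self._stop_event.is_set():
            try:
                self.data_queue.put(self._pull_fn())
            except queue.Empty:
                continue                        # timeout: poll again
            except Exception:
                break                           # unexpected error: let the thread exit

    def drain(self):
        # Return everything currently queued without blocking.
        items = []
        while True:
            try:
                items.append(self.data_queue.get_nowait())
            except queue.Empty:
                return items

    def close(self):
        self._stop_event.set()
        self.worker_thread.join()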