bowei.fw 2025-03-20 15:24:31 +08:00
commit f8586e47c8
40 changed files with 1829 additions and 782 deletions

View File

@ -13,7 +13,6 @@ from functioncall.base import logging
logger = logging.getLogger("Functioncall")
FUNCTIONCALL_SERVICE_DOMAIN = os.getenv(
"FUNCTIONCALL_SERVICE_DOMAIN",
"",
@ -31,8 +30,8 @@ async def async_invoke_function(
function_name: str,
timeout: aiohttp.ClientTimeout,
payload: Dict[str, Any] = None,
max_retries: int = 3,
initial_retry_interval: float = 0.1,
max_retries: int = 100,
initial_retry_interval: float = 0.5,
max_retry_interval: float = 10.0,
):
if payload is None:
@ -63,7 +62,7 @@ async def async_invoke_function(
except asyncio.TimeoutError as e:
logger.warning(
f"Request timeout after {timeout}s, URL: {url}, Headers: {session.headers}, payload: {payload}"
f"Request timeout after {timeout}s, URL: {url}, Headers: {session.headers}"
)
break
@ -85,7 +84,7 @@ async def async_invoke_function(
async def batch_function_call_async(
payload_list, function_name, timeout, concurrency=1000
payload_list, function_name, timeout, concurrency=1500
):
connector = aiohttp.TCPConnector(limit=0)
async with aiohttp.ClientSession(connector=connector) as session:
@ -120,7 +119,7 @@ async def batch_function_call_async(
p90 = calculate_percentile(elapsed_times, 90)
p99 = calculate_percentile(elapsed_times, 99)
logger.info(
f"Longest functioncall took {max_elapsed:.4f} seconds, header: {max_elapsed_header}, p50: {p50}, p90: {p90}, p99: {p99}"
f"Longest functioncall {function_name} took {max_elapsed:.4f} seconds, header: {max_elapsed_header}, timeout: {timeout}, p50: {p50}, p90: {p90}, p99: {p99}"
)
return data_list
@ -129,19 +128,21 @@ async def batch_function_call_async(
def get_function_name(runtime_type):
if runtime_type == "python_code":
return "realhf_code_verify"
if runtime_type == "python_live_code_bench":
return "python_live_code_bench"
elif runtime_type == "python_math":
return "realhf_math_verify"
return "python_math"
return "empty_code"
def batch_function_call(payload_list, runtime_type, timeout=30):
def batch_function_call(payload_list, runtime_type, timeout):
start_time = time.time()
function_name = get_function_name(runtime_type)
result = asyncio.run(
batch_function_call_async(payload_list, function_name, timeout)
)
execution_time = time.time() - start_time
logger.debug(
logger.info(
f"Batch function call done, runtime type: {runtime_type}, batch size: {len(payload_list)}, cost: {execution_time * 1000:.0f} ms"
)
return result
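
Note: the retry knobs above (max_retries, initial_retry_interval, max_retry_interval) suggest a capped exponential backoff inside async_invoke_function. The loop body is not shown in this hunk, so the following is only a sketch of that assumed shape, not the actual implementation.

import asyncio

async def retry_with_backoff(call, max_retries=100, initial_retry_interval=0.5, max_retry_interval=10.0):
    # Capped exponential backoff: the sleep doubles after each failure, up to the cap.
    interval = initial_retry_interval
    for attempt in range(max_retries):
        try:
            return await call()  # `call` is a zero-argument coroutine factory
        except Exception:
            if attempt == max_retries - 1:
                raise
            await asyncio.sleep(interval)
            interval = min(interval * 2, max_retry_interval)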

View File

@ -714,7 +714,7 @@ def reliability_guard(maximum_memory_bytes=None):
import builtins
builtins.exit = None
# builtins.exit = None
builtins.quit = None
import os

View File

@ -29,7 +29,7 @@ def load_problems_with_testcase_batch(path, debug=False, test_case_batch_size=No
# parse one problem
row = json.loads(line.strip().decode("utf-8"))
query_id = str(row["id"])
query_id = str(row.get("id", row.get("query_id")))
input_output = json.loads(row["input_output"]) if "input_output" in row else {}
inputs = input_output.get("inputs", [])
outputs = input_output.get("outputs", [])
@ -66,7 +66,9 @@ def load_problems_with_testcase_batch(path, debug=False, test_case_batch_size=No
global_problems = None
def code_verify(generateds, query_ids, debug=False, timeout=20, timeout_for_testcase=6):
def code_verify(
generateds, query_ids, debug=False, timeout=1000, timeout_for_testcase=6
):
assert len(generateds) == len(query_ids), (
len(generateds),
len(query_ids),
@ -116,14 +118,11 @@ def code_verify(generateds, query_ids, debug=False, timeout=20, timeout_for_test
value = 0
if rsp and "result" in rsp and not any(x != True for x in rsp["result"]):
value = 1
else:
logger.debug(
f"Functioncall code verify failed, query index: {query_index}, query id: {query_id}, results: {rsp}"
f"Functioncall code verify not passed, query index: {query_index}, query id: {query_id}, results: {rsp}"
)
# logger.debug(f"query index: {idx}, value: {value}, results[query_index]: {results[query_index]}")
results[query_index] = results[query_index] and value
return results
@ -133,6 +132,15 @@ if __name__ == "__main__":
def create_test_params(count=10):
global global_problems
if global_problems is None:
global_problems = load_problems_with_testcase_batch(
os.getenv(
"REAL_CODE_METADATA_PATH",
"/storage/datasets/codeparrot-apps-test.jsonl",
),
debug=True,
test_case_batch_size=20,
)
codes, query_ids = [], []
idx = 0
for query_id, problems in global_problems.items():
@ -149,6 +157,6 @@ if __name__ == "__main__":
return codes, query_ids
codes, query_ids = create_test_params(1000)
result = code_verify(codes, query_ids, True, 100)
codes, query_ids = create_test_params(100)
result = code_verify(codes, query_ids, True)
print(result)

View File

@ -1,4 +1,5 @@
import json
import time
from parser import extract_answer
from grader import math_equal
@ -6,11 +7,8 @@ from grader import math_equal
def process_results(answer, solution):
extracted_answer = extract_answer(answer, "math", use_last_number=False)
extracted_solution = extract_answer(solution, "math", use_last_number=True)
extracted_solution = solution
print(
f"extracted_answer: {extracted_answer}, extracted_solution: {extracted_solution}, equal: {math_equal(extracted_answer, extracted_solution)}"
)
if extracted_answer is None or extracted_answer.strip() in ["None", "none", ""]:
retval = 0
elif math_equal(extracted_answer, extracted_solution, timeout=True):
@ -24,16 +22,19 @@ def process_results(answer, solution):
def handle(event, context):
answers = event.get("answers", "")
solutions = event.get("solutions", "")
# print(f"math payload:{event}\n")
# answers and solutions are json lists, and call process_results then collect result into a list
if isinstance(answers, str):
answers = json.loads(answers)
if isinstance(solutions, str):
solutions = json.loads(solutions)
query_ids = event.get("query_ids", "")
results = []
for answer, solution in zip(answers, solutions):
results.append(process_results(answer, solution))
for answer, solution, query_id in zip(
answers,
solutions,
query_ids,
):
start_time = time.time()
result = process_results(answer, solution)
results.append(result)
print(
f"query_id: {query_id}, result: {result}, current cost: {(time.time() - start_time) * 1000:.0f} ms"
)
return results
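
For reference, a hypothetical payload for handle() based on the fields it reads (answers/solutions may arrive either as JSON-encoded strings or as plain lists; the values reuse the sample from the math_verify test below):

event = {
    "answers": ["\\boxed{-\\frac{2}{3}}"],
    "solutions": ["-\\frac{2}{3}"],
    "query_ids": ["fe11b471-1aa9-4867-958f-a0a811c85f92"],
}
results = handle(event, context=None)  # expected [1] when the boxed answer matches the solution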

View File

@ -22,7 +22,7 @@ def loadJson(dataDir):
id2info = None
def math_verify(generateds: List, query_ids: List, batch_size=10, timeout=5) -> List:
def math_verify(generateds: List, query_ids: List, batch_size=10, timeout=1000) -> List:
global id2info
if id2info is None:
id2info = loadJson(
@ -43,7 +43,7 @@ def math_verify(generateds: List, query_ids: List, batch_size=10, timeout=5) ->
for idx, (query_id, generated) in enumerate(zip(query_ids, generateds)):
base_query_id = query_id.split("@idx:")[0]
info = id2info[base_query_id]
for cur_solution in info["solutions"]:
for cur_solution in info["answers"]:
parameters.append((generated, cur_solution, idx))
query_indices.append(idx)
@ -58,7 +58,6 @@ def math_verify(generateds: List, query_ids: List, batch_size=10, timeout=5) ->
"query_ids": [query_ids[i] for i in indices],
}
# print(batch_args)
batch_args_list.append(batch_args)
results_batch = batch_function_call(batch_args_list, "python_math", timeout)
@ -67,15 +66,15 @@ def math_verify(generateds: List, query_ids: List, batch_size=10, timeout=5) ->
# Map results back to original indices
index = 0
for batch_idx, results in enumerate(results_batch):
query_index = query_indices[index]
if not isinstance(results, list) or len(results) == 0:
index += len(batch_args_list[batch_idx]["answers"])
logger.warning(
f"Invalid functioncall math results: {results}, batch index:{batch_idx}, query index: {query_index}, params: {batch_args_list[batch_idx]['answers']}."
f"Invalid functioncall math results: {results}, batch index:{batch_idx}, query index: {query_indices[index]}, params: {batch_args_list[batch_idx]['answers']}."
)
continue
for result in results:
query_index = query_indices[index]
if (
isinstance(result, list)
and len(result) > 0
@ -90,23 +89,33 @@ def math_verify(generateds: List, query_ids: List, batch_size=10, timeout=5) ->
index += 1
logger.info(
f"verify math with query size={len(query_ids)}, takes {time.time() - start_time:.4f} seconds, result: {labels}"
f"verify math with query size={len(query_ids)}, takes {time.time() - start_time:.4f} seconds"
)
return labels
if __name__ == "__main__":
sample = {
"prompt": "",
"query_id": "fe11b471-1aa9-4867-958f-a0a811c85f92",
"answer": "\\boxed{-\\frac{2}{3}}",
}
start_time = time.time()
batch_size = 10
result = math_verify(
[sample["answer"]] * batch_size, [sample["query_id"] for _ in range(batch_size)]
)
# sample = {
# "prompt": "",
# "query_id": "fe11b471-1aa9-4867-958f-a0a811c85f92",
# "answer": "\\boxed{-\\frac{1}{30}}",
# }
hint = f"batch_size: {batch_size}, total time : {(time.time() - start_time) * 1000:.0f} ms"
if id2info is None:
id2info = loadJson(
os.getenv(
"REAL_MATH_MEATADATA_PATH",
"/storage/datasets/id2info.json",
)
)
answers = []
query_ids = []
for id, value in id2info.items():
answers.append(value["solutions"][0])
query_ids.append(id)
start_time = time.time()
result = math_verify(answers[:200], query_ids[:200])
print(result)
print(hint)

View File

@ -684,7 +684,6 @@ class SequenceSample:
class DataBatchMeta:
dp_rank: int
meta_sample: SequenceSample | None
is_final_batch: bool
@dataclasses.dataclass

View File

@ -3,10 +3,12 @@
# Licensed under the Apache License, Version 2.0 (the "License").
import abc
import asyncio
import dataclasses
import keyword
from typing import *
import aiohttp
import numpy as np
import torch
import torch.utils.data
@ -106,6 +108,139 @@ class GenerationHyperparameters:
f"To use CUDAGraph, ReaL's PyTorch version should be at least 2.3.0."
)
def new(self, **kwargs):
args = dataclasses.asdict(self)
args.update(kwargs)
return GenerationHyperparameters(**args)
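
The new() helper returns a modified copy without mutating the original, e.g. (max_new_tokens is one of the generation fields referenced elsewhere in this commit):

gconfig = GenerationHyperparameters()
eval_gconfig = gconfig.new(max_new_tokens=512)  # gconfig itself is unchanged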
@dataclasses.dataclass
class APIGenerateInput:
qid: Hashable
group_idx: int
prompt_ids: List[int]
input_ids: List[int]
gconfig: GenerationHyperparameters
stop_token_ids: List[int] = dataclasses.field(default_factory=list)
return_logprob: bool = False
@dataclasses.dataclass
class APIGenerateOutput:
qid: Hashable
group_idx: int
prompt_ids: List[int]
input_ids: List[int]
output_ids: List[int] = dataclasses.field(default_factory=list)
output_logprobs: List[int] = dataclasses.field(default_factory=list)
no_eos: bool = True
success: bool = False
latency: float = 0.0
ttft: float = 0.0 # Time to first token
itl: List[float] = dataclasses.field(
default_factory=list
) # List of inter-token latencies
error: str = ""
@classmethod
def from_input(cls, inp: APIGenerateInput):
return cls(
qid=inp.qid,
group_idx=inp.group_idx,
prompt_ids=inp.prompt_ids,
input_ids=inp.input_ids,
)
@property
def output_len(self):
return len(self.output_ids)
@property
def input_len(self):
return len(self.input_ids)
@property
def prompt_len(self):
return len(self.prompt_ids)
@property
def gen_len(self):
return self.output_len + self.input_len - self.prompt_len
@dataclasses.dataclass
class BundledGenerationOutputs:
qid: Hashable
prompt_ids: List[int]
seqs: List[List[int]]
no_eos: List[bool]
@classmethod
def from_single(cls, outputs: List[APIGenerateOutput]):
assert len(set(o.qid for o in outputs)) == 1
return cls(
qid=outputs[0].qid,
prompt_ids=outputs[0].prompt_ids,
seqs=[o.input_ids + o.output_ids for o in outputs],
no_eos=[o.no_eos for o in outputs],
)
@property
def seqlens(self):
return [len(seq) for seq in self.seqs]
@property
def prompt_len(self):
return len(self.prompt_ids)
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
class LLMAPIClient:
def __init__(
self, generate_url: str, update_weights_url: str, concurrency_limit: int = -1
):
self.update_weights_url = update_weights_url
self.generate_url = generate_url
self.concurrency_limit = concurrency_limit
self.session: aiohttp.ClientSession
self.semaphore: asyncio.Semaphore
async def __aenter__(self):
conn = aiohttp.TCPConnector(limit=0, ttl_dns_cache=300)
self.session = aiohttp.ClientSession(
timeout=AIOHTTP_TIMEOUT,
connector=conn,
read_bufsize=1024 * 1024 * 10,
)
if self.concurrency_limit > 0:
self.semaphore = asyncio.Semaphore(self.concurrency_limit)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
if self.session:
await self.session.close()
async def async_add_generate_request(
self, req: APIGenerateInput, stream: bool = True
) -> APIGenerateOutput:
if self.concurrency_limit > 0:
async with self.semaphore:
return await self._do_generate(req, stream=stream)
else:
return await self._do_generate(req, stream=stream)
async def _do_generate(
self, req: APIGenerateInput, stream: bool = True
) -> APIGenerateOutput:
raise NotImplementedError()
async def async_update_weights_from_disk(self, path):
raise NotImplementedError()
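
A minimal usage sketch of the LLMAPIClient interface above. DummyClient and its canned output are assumptions for illustration; a real subclass would implement _do_generate by streaming tokens from an inference server, and GenerationHyperparameters is assumed to be constructible with defaults.

import asyncio

class DummyClient(LLMAPIClient):
    async def _do_generate(self, req: APIGenerateInput, stream: bool = True) -> APIGenerateOutput:
        out = APIGenerateOutput.from_input(req)
        out.output_ids = [42]  # pretend the server returned a single token
        out.success = True
        return out

async def demo():
    async with DummyClient("http://gen", "http://update", concurrency_limit=8) as client:
        inp = APIGenerateInput(qid=0, group_idx=0, prompt_ids=[1, 2], input_ids=[1, 2],
                               gconfig=GenerationHyperparameters())
        out = await client.async_add_generate_request(inp)
        print(BundledGenerationOutputs.from_single([out]).seqlens)  # [3]

asyncio.run(demo())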
@dataclasses.dataclass
class ReaLMoEConfig:

View File

@ -131,6 +131,65 @@ class vLLMConfig:
additional_engine_args: Dict = dataclasses.field(default_factory=dict)
@dataclasses.dataclass
class SGLangConfig:
disable_cuda_graph: bool = False
disable_radix_cache: bool = False
disable_jump_forward: bool = False
disable_cuda_graph_padding: bool = False
enable_nccl_nvls: bool = False
disable_outlines_disk_cache: bool = False
disable_custom_all_reduce: bool = False
disable_mla: bool = False
disable_overlap_schedule: bool = False
enable_mixed_chunk: bool = False
enable_dp_attention: bool = False
enable_ep_moe: bool = False
enable_torch_compile: bool = False
torch_compile_max_bs: int = 32
cuda_graph_max_bs: Optional[int] = None
cuda_graph_bs: Optional[List[int]] = None
torchao_config: str = ""
enable_nan_detection: bool = False
enable_p2p_check: bool = False
triton_attention_reduce_in_fp32: bool = False
triton_attention_num_kv_splits: int = 8
num_continuous_decode_steps: int = 1
enable_memory_saver: bool = False
allow_auto_truncate: bool = False
return_hidden_states: bool = False
# NOTE: to avoid the illegal memory access error
attention_backend: Optional[str] = "triton"
sampling_backend: Optional[str] = None
context_length: Optional[int] = None
mem_fraction_static: Optional[float] = None
max_running_requests: Optional[int] = None
max_total_tokens: Optional[int] = None
chunked_prefill_size: Optional[int] = None
max_prefill_tokens: int = 16384
schedule_policy: str = "lpm"
schedule_conservativeness: float = 1.0
cpu_offload_gb: int = 0
hybrid_train: bool = False
@dataclasses.dataclass
class DistributedDataParallelConfig:
"""Configuration for Megatron DistributedDataParallel.
Some default options have been overwritten.
"""
grad_reduce_in_fp32: bool = False
overlap_grad_reduce: bool = True
overlap_param_gather: bool = False
align_param_gather: bool = False
use_distributed_optimizer: bool = True
check_for_nan_in_grad: bool = False
bucket_size: Optional[int] = None
average_in_collective: bool = False
fp8_param_gather: bool = False
@dataclasses.dataclass
class MegatronConfig:
"""When using the DistributedOptimizer of Megatron, parameters and
@ -177,12 +236,11 @@ class MegatronConfig:
make it functionally correct. The DeepSpeed code is too hard to read and modify.
"""
overlap_grad_reduce: bool = True
overlap_param_gather: bool = False
accumulate_allreduce_grads_in_fp32: bool = False
# addtional args
additional_config: Dict = dataclasses.field(default_factory=dict)
ddp: DistributedDataParallelConfig = dataclasses.field(
default_factory=DistributedDataParallelConfig
)
# Don't use MegatronOptimizerConfig here because OmegaConf
# does not recognize the annotation "torch.dtype"
@dataclasses.dataclass
@ -233,6 +291,7 @@ class ModelTrainEvalConfig:
)
megatron: MegatronConfig = dataclasses.field(default_factory=MegatronConfig)
vllm: vLLMConfig = dataclasses.field(default_factory=vLLMConfig)
sglang: SGLangConfig = dataclasses.field(default_factory=SGLangConfig)
init_from_scratch: bool = False
init_critic_from_actor: bool = False

View File

@ -7,10 +7,13 @@ import functools
import json
import multiprocessing
import os
import pickle
import re
import socket
import subprocess
try:
import uvloop
uvloop.install()
except (ModuleNotFoundError, ImportError):
pass
import torch
from omegaconf import OmegaConf
@ -60,7 +63,7 @@ def main_worker(args):
worker_index_start + args.wprocs_per_jobstep,
args.wprocs_in_job + args.wproc_offset,
)
if args.worker_type != "model_worker":
if args.worker_type == "master_worker":
try:
# CUDA_VISIBLE_DEVICES is set by slurm on PPU nodes
# we need to remove it on CPU workers

View File

@ -80,6 +80,7 @@ TRITON_CACHE_PATH = f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/triton"
DATASET_CACHE_PATH = f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/datasets"
PROFILER_CACHE_PATH = f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/profiler"
PARAM_REALLOC_PATH = f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/param_realloc"
SGLANG_CACHE_PATH = f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/sglang"
TORCH_EXTENSIONS_DIR = (
f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/torch/extensions"
)
@ -165,6 +166,7 @@ os.makedirs(DATASET_CACHE_PATH, exist_ok=True)
os.makedirs(PROFILER_CACHE_PATH, exist_ok=True)
os.makedirs(TORCH_EXTENSIONS_DIR, exist_ok=True)
os.makedirs(QUICKSTART_EXPR_CACHE_PATH, exist_ok=True)
os.makedirs(SGLANG_CACHE_PATH, exist_ok=True)
# _model_name will be changed in the model_scope context manager
_model_name: "ModelName" = None
@ -177,6 +179,9 @@ _grids: Dict["ModelName", "ParallelGrid"] = {}
_pgroups: Dict["ModelName", Any] = (
{}
) # torch.distributed.ProcessGroup, not type hint here to avoid importing torch
_cpu_pgroups: Dict["ModelName", Any] = (
{}
) # torch.distributed.ProcessGroup, not type hint here to avoid importing torch
_pgroup_ranks: Dict["ModelName", List[int]] = {}
_self_group = None
_rank_mapping: Dict["ModelName", Dict["ModelShardID", int]] = {}
@ -259,6 +264,13 @@ def set_parallelism_group(model_name: "ModelName", pgroup, ranks):
_pgroup_ranks[model_name] = ranks
def set_cpu_parallelism_group(model_name: "ModelName", pgroup):
global _cpu_pgroups
if model_name in _cpu_pgroups:
raise RuntimeError(f"Parallelism group for model {model_name} is already set.")
_cpu_pgroups[model_name] = pgroup
def set_self_group(pgroup):
global _self_group
if _self_group is not None:
@ -382,6 +394,15 @@ def parallelism_group():
return _pgroups[_model_name]
def cpu_parallelism_group():
"""Returns the GLOO 3D parallelism group of a specific model."""
if _model_name is None:
raise RuntimeError("Global constant `model_name` is accessed before set.")
if _cpu_pgroups.get(_model_name, None) is None:
raise RuntimeError(f"Parallelism group for model {_model_name} is not set.")
return _cpu_pgroups[_model_name]
def parallelism_group_ranks():
if _model_name is None:
raise RuntimeError("Global constant `model_name` is accessed before set.")
@ -451,6 +472,10 @@ def prev_pipe_stage():
) % pipe_parallel_world_size()
def is_dp_head():
return is_last_pipe_stage() and model_parallel_rank() == 0
def model_parallel_rank() -> int:
"""Return the rank inside the tensor parallelism group."""
try:

View File

@ -2,7 +2,7 @@
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import heapq
import bisect
import itertools
from typing import Any, List, Tuple, Union
@ -159,8 +159,9 @@ def _ffd_allocate(
as possible.
1. Sort the numbers in reverse order.
2. If the number of groups is less than, create a new group.
3. If the new number fits into the smallest group, add it into the group.
2. If the number of groups is less than `min_groups`, create a new group.
3. For a new number, find all groups with the capacity to hold the new number.
Put the new number into the group with the smallest size.
4. Otherwise, create a new group.
"""
value_indices = np.argsort(-values)
@ -172,12 +173,17 @@ def _ffd_allocate(
len(group_values) < min_groups
or group_values[0][0] + values[idx] > capacity
):
heapq.heappush(group_values, (float(values[idx]), group_cnt))
bisect.insort(group_values, (float(values[idx]), group_cnt))
group_indices.append([idx])
group_cnt += 1
else:
v, group_idx = heapq.heappop(group_values)
heapq.heappush(group_values, (float(v + values[idx]), group_idx))
i = bisect.bisect_right(group_values, (capacity - values[idx], len(values)))
candidates = [group_values[j][1] for j in range(i)]
lens = [len(group_indices[g]) for g in candidates]
j = np.argmin(lens)
v, group_idx = group_values.pop(j)
assert group_idx == candidates[j]
bisect.insort(group_values, (float(values[idx] + v), group_idx))
group_indices[group_idx].append(idx)
return group_indices
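
A self-contained restatement of the revised first-fit-decreasing strategy above, for readers who want to run it in isolation (it mirrors the logic of _ffd_allocate but is not the library function itself):

import bisect
from typing import List, Tuple

import numpy as np

def ffd_allocate_sketch(values: List[float], capacity: float, min_groups: int = 1) -> List[List[int]]:
    # Sort descending; keep (accumulated value, group id) sorted so that all groups
    # still able to hold the next number can be found with bisect, then put the
    # number into the candidate group with the fewest members.
    group_values: List[Tuple[float, int]] = []
    group_indices: List[List[int]] = []
    for idx in np.argsort(-np.asarray(values)):
        v = float(values[idx])
        i = bisect.bisect_right(group_values, (capacity - v, len(values)))
        if len(group_values) < min_groups or i == 0:
            bisect.insort(group_values, (v, len(group_indices)))
            group_indices.append([int(idx)])
        else:
            candidates = [group_values[j][1] for j in range(i)]
            j = int(np.argmin([len(group_indices[g]) for g in candidates]))
            acc, gid = group_values.pop(j)
            bisect.insort(group_values, (acc + v, gid))
            group_indices[gid].append(int(idx))
    return group_indices

print(ffd_allocate_sketch([7, 5, 4, 3, 1], capacity=10))  # [[0, 3], [1, 2, 4]]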

View File

@ -268,6 +268,23 @@ class MemoryNameRecordRepository(NameRecordRepository):
class NfsNameRecordRepository(NameRecordRepository):
RECORD_ROOT = f"{cluster_spec.fileroot}/name_resolve/"
LOCK_FILE = f"{cluster_spec.fileroot}/name_resolve/LOCK"
os.makedirs(RECORD_ROOT, exist_ok=True)
@staticmethod
def locked(fn: Callable) -> Callable:
def fn_(*args, **kwargs):
import fcntl
with open(NfsNameRecordRepository.LOCK_FILE, "w") as fd:
fcntl.flock(fd, fcntl.LOCK_EX)
try:
res = fn(*args, **kwargs)
finally:
fcntl.flock(fd, fcntl.LOCK_UN)
return res
return fn_
def __init__(self, **kwargs):
self.__to_delete = set()
@ -280,6 +297,7 @@ class NfsNameRecordRepository(NameRecordRepository):
def __file_path(name):
return os.path.join(NfsNameRecordRepository.__dir_path(name), "ENTRY")
@locked
def add(
self,
name,
@ -299,6 +317,7 @@ class NfsNameRecordRepository(NameRecordRepository):
if delete_on_exit:
self.__to_delete.add(name)
@locked
def delete(self, name):
path = self.__file_path(name)
if not os.path.isfile(path):
@ -322,12 +341,22 @@ class NfsNameRecordRepository(NameRecordRepository):
else:
logger.info("No such name resolve path: %s", dir_path)
@locked
def get(self, name):
path = self.__file_path(name)
if not os.path.isfile(path):
raise NameEntryNotFoundError(path)
with open(path, "r") as f:
return f.read().strip()
for _ in range(100):
# HACK: dealing with the possible OSError: Stale file handle
try:
with open(path, "r") as f:
return f.read().strip()
except OSError as e:
if e.errno == 116:
time.sleep(5e-3)
continue
raise e
raise RuntimeError("Failed to read value for %s" % name)
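
Side note: the magic number 116 caught above is errno.ESTALE on Linux, the "Stale file handle" error an NFS client raises when its cached handle goes stale.

import errno
print(errno.ESTALE)  # 116 on Linux, the code checked in the retry loop above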
def get_subtree(self, name_root):
dir_path = self.__dir_path(name_root)

View File

@ -60,3 +60,7 @@ def distributed_local_peer(experiment_name, trial_name, host_name, model_name):
def distributed_master(experiment_name, trial_name, model_name):
return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/distributed/master/{model_name}"
def model_version(experiment_name, trial_name, model_name):
return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/model_version/{model_name}"

View File

@ -6,12 +6,27 @@ import socket
from contextlib import closing
def find_free_port():
"""From stackoverflow Issue 1365265."""
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
s.bind(("", 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
return s.getsockname()[1]
def find_free_port(low=1, high=65536, exclude_ports=None):
"""Find a free port within the specified range, excluding certain ports."""
if exclude_ports is None:
exclude_ports = set()
while True:
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
s.bind(("", 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
port = s.getsockname()[1]
if low <= port <= high and port not in exclude_ports:
return port
def find_multiple_free_ports(count, low=1, high=65536):
"""Find multiple mutually exclusive free ports."""
free_ports = set()
for _ in range(count):
port = find_free_port(low, high, exclude_ports=free_ports)
free_ports.add(port)
return list(free_ports)
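
Example usage of the new helper (the port range here is arbitrary):

ports = find_multiple_free_ports(3, low=20000, high=60000)
print(ports)  # three distinct free ports within [20000, 60000]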
def gethostname():

View File

@ -2,11 +2,13 @@
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import os
from importlib.metadata import version
from typing import List
import realhf.api.core.model_api as model_api
from packaging.version import Version
from realhf.api.quickstart.device_mesh import RPCAllocation
from realhf.api.quickstart.model import ModelTrainEvalConfig, vLLMConfig
from realhf.api.quickstart.model import ModelTrainEvalConfig, SGLangConfig, vLLMConfig
from realhf.base import logging
logger = logging.getLogger(__name__)
@ -27,7 +29,22 @@ def check_valid_vllm(role: str, vllm: vLLMConfig, rpc_allocs: List[RPCAllocation
if vllm.hybrid_train and not any(rpc.is_train() for rpc in rpcs):
logger.warning("vLLM hybrid_train is enabled, but no training RPCs are found.")
if vllm.hybrid_train and not vllm.enforce_eager:
raise ValueError("vLLM hybrid_train requires eager mode to be enabled.")
logger.warning(
"For version < 0.7.0, vLLM hybrid_train requires eager mode to be enabled. "
"The user has the responsibility to ensure the version is correct."
)
def check_valid_sglang(
role: str, sglang: SGLangConfig, rpc_allocs: List[RPCAllocation]
):
rpcs = [alloc.rpc for alloc in rpc_allocs if alloc.rpc.role == role]
if sglang.hybrid_train and not any(rpc.is_train() for rpc in rpcs):
logger.warning(
"SGLang hybrid_train is enabled, but no training RPCs are found."
)
if sglang.hybrid_train and not sglang.disable_cuda_graph:
raise ValueError("SGLang hybrid_train requires CUDA graph to be disabled.")
def check_valid_optimizer(model: ModelTrainEvalConfig):

View File

@ -4,6 +4,7 @@
import dataclasses
import itertools
import os
from collections import defaultdict
from typing import *
@ -51,6 +52,7 @@ from realhf.experiments.common.check import (
check_valid_model_and_path,
check_valid_optimizer,
check_valid_parallel_batch_size,
check_valid_sglang,
check_valid_vllm,
)
from realhf.experiments.common.utils import (
@ -69,7 +71,7 @@ import realhf.api.from_hf # isort:skip
logger = logging.getLogger("CommonExperimentConfig", "colored")
vLLM_HYBRID_TRAIN_DECOUPLE_ALLOC_WARN = False
GEN_HYBRID_TRAIN_DECOUPLE_ALLOC_WARN = False
@dataclasses.dataclass
@ -108,7 +110,7 @@ class CommonExperimentConfig(Experiment):
- A regex pattern like ``d${DP}p${PP}m${TP}``\: Identical parallelization for all MFCs with ${DP}-way data parallelism, ${PP}-way pipeline parallelism, and ${TP}-way model parallelism.
- A regex pattern like ``vllm.{IdentPara}+{IdentPara}``\: Decoupled generation (vLLM) and training allocations with corresponding identical parallelization strategies. Note that the pipeline parallel degree of vLLM can only be 1.
- A regex pattern like ``{vllm|sglang}.{IdentPara}+{IdentPara}``\: Decoupled generation and training allocations with corresponding identical parallelization strategies.
- Key-value pairs with MFC names and their parallel strategies in the whole cluster, e.g., ``actor_gen:d4m2p1,*:d2p2m2`` specifies a ``d4m2p1`` strategy for actor generation and ``d2p2m2`` for other MFCs in a world of 8 GPUs.
@ -363,6 +365,10 @@ class CommonExperimentConfig(Experiment):
),
)
@property
def _allocation_mode(self):
return AllocationMode.from_str(self.allocation_mode)
def _get_rpc_allocations(self) -> List[RPCAllocation]:
if self.allocation_mode == "manual" and self.nodelist is None:
logger.warning(
@ -373,9 +379,7 @@ class CommonExperimentConfig(Experiment):
f"and n_gpus_per_node {self.n_gpus_per_node}."
)
self.__check_legal_allocation_options()
self._allocation_mode = AllocationMode.from_str(self.allocation_mode)
self._check_legal_allocation_options()
rpcs = self.rpcs
if self.allocation_mode == "search":
@ -452,9 +456,9 @@ class CommonExperimentConfig(Experiment):
raise ValueError(
"The multiplication of 3D parallel degrees "
"does not equal to the number of gpus. "
"Note that the device mesh of vLLM should be disjoint from the device mesh of other MFCs, "
"Note that the device mesh of vLLM/SGLang should be disjoint from the device mesh of other MFCs, "
"so their summation should be equal to the total number of gpus. "
f"dp={dp}, pp={pp}, mp={mp}, vllm.dp={gdp}, vllm.pp={gpp}, vllm.mp={gmp}, "
f"dp={dp}, pp={pp}, mp={mp}, gen.dp={gdp}, gen.pp={gpp}, gen.mp={gmp}, "
f"n_nodes={self.n_nodes}, n_gpus_per_node={self.n_gpus_per_node}"
)
alloc = RPCAllocation(
@ -535,7 +539,7 @@ class CommonExperimentConfig(Experiment):
def _get_model_worker_configs(
self, rpc_allocs: List[RPCAllocation]
) -> List[ModelWorker]:
self.__run_model_sanity_check(rpc_allocs)
self._run_model_sanity_check(rpc_allocs)
model_worker = []
shard_counter = defaultdict(lambda: 0)
@ -557,7 +561,7 @@ class CommonExperimentConfig(Experiment):
tokenizer_name_or_path=self.tokenizer_name_or_path,
)
# vLLM enabled model worker, shortcut case
# decoupled allocation, shortcut case
if (
self._allocation_mode.is_decoupled()
and self.gen_device_mesh.mapping[i, j]
@ -574,22 +578,32 @@ class CommonExperimentConfig(Experiment):
is_train=False,
)
model_cfg = self.models[model_name.role]
global vLLM_HYBRID_TRAIN_DECOUPLE_ALLOC_WARN
gen_backend_name = ""
if self._allocation_mode.is_decoupled_vllm():
gen_backend_name = "vllm"
elif self._allocation_mode.is_decoupled_sglang():
gen_backend_name = "sglang"
backend_cfg = getattr(model_cfg, gen_backend_name)
global GEN_HYBRID_TRAIN_DECOUPLE_ALLOC_WARN
if (
model_cfg.vllm.hybrid_train
and not vLLM_HYBRID_TRAIN_DECOUPLE_ALLOC_WARN
backend_cfg.hybrid_train
and not GEN_HYBRID_TRAIN_DECOUPLE_ALLOC_WARN
):
logger.warning(
"vLLM hybrid_train=True takes no effect for the decoupled allocation"
"hybrid_train=True takes no effect for the decoupled allocation"
)
vLLM_HYBRID_TRAIN_DECOUPLE_ALLOC_WARN = True
model_cfg.vllm.hybrid_train = False
check_valid_vllm(model_name.role, model_cfg.vllm, rpc_allocs)
GEN_HYBRID_TRAIN_DECOUPLE_ALLOC_WARN = True
backend_cfg.hybrid_train = False
if gen_backend_name == "vllm":
check_valid_vllm(model_name.role, model_cfg.vllm, rpc_allocs)
elif gen_backend_name == "sglang":
check_valid_sglang(model_name.role, model_cfg.sglang, rpc_allocs)
shard_idx = shard_counter[model_name]
vllm_dict_args: Dict[str, Any] = OmegaConf.to_container(
model_cfg.vllm, resolve=True
)
dict_args: Dict[str, Any] = asdict(backend_cfg)
mw.shards.append(
StandaloneModelShardAbstraction(
id=ModelShardID(
@ -603,11 +617,11 @@ class CommonExperimentConfig(Experiment):
"tokenizer", args=dict(tokenizer_path=model_cfg.path)
),
backend=ModelBackendAbstraction(
"vllm",
gen_backend_name,
args=dict(
model_path=model_cfg.path,
dtype="bfloat16" if model_cfg.bf16 else "float16",
**vllm_dict_args,
**dict_args,
),
),
)
@ -684,29 +698,27 @@ class CommonExperimentConfig(Experiment):
rpc.is_generate() for rpc in rpcs
):
assert len(rpcs) == 1 and rpcs[0].is_generate(), rpcs
vllm_dict_args: Dict[str, Any] = asdict(model_cfg.vllm)
assert (
not model_cfg.sglang.hybrid_train
), "vLLM and SGLang cannot be enabled at the same time"
dict_args: Dict[str, Any] = asdict(model_cfg.vllm)
check_valid_vllm(model_name.role, model_cfg.vllm, rpc_allocs)
backend = ModelBackendAbstraction(
"vllm",
args=dict(
model_path=model_cfg.path,
**vllm_dict_args,
**dict_args,
),
)
elif model_cfg.sglang.hybrid_train and any(
rpc.is_generate() for rpc in rpcs
):
raise NotImplementedError(
"SGLang hybrid_train=True is not supported yet."
)
else:
backend = make_inf_backend_config(model_cfg, rpc_alloc.parallel)
if any(rpc.is_generate() for rpc in rpcs) and backend.type_ not in [
"vllm",
"sglang",
]:
print(rpcs, model_name, backend.type_)
raise ValueError(
"vLLM or SGLang is not enabled for generation. "
"This behavior has been deprecated. "
"Please set model.vllm.hybrid_train=True "
"or model.sglang.hybrid_train=True."
)
check_valid_vllm(model_name.role, model_cfg.vllm, rpc_allocs)
if mapping[i, j]:
shard_idx = shard_counter[model_name]
mw.shards.append(
@ -749,7 +761,7 @@ class CommonExperimentConfig(Experiment):
evaluator=self.auto_eval_config,
)
def __check_legal_allocation_options(self):
def _check_legal_allocation_options(self):
if self.n_nodes > 1 and self.mode == "local":
raise ValueError(
"Cannot run multi-node experiment in local mode, "
@ -791,7 +803,7 @@ class CommonExperimentConfig(Experiment):
f"RPC {rpc.name} model name {rpc.model_name.role} is not in models."
)
def __run_model_sanity_check(self, rpc_allocs: List[RPCAllocation]):
def _run_model_sanity_check(self, rpc_allocs: List[RPCAllocation]):
for alloc in rpc_allocs:
check_valid_parallel_batch_size(alloc)
for role, model in self.models.items():

View File

@ -90,6 +90,7 @@ class PPOHyperparameters:
eps_clip: float = 0.2
value_eps_clip: float = 0.2
disable_value: bool = False
recompute_logprob: bool = False
max_reward_clip: float = 20.0
reward_output_scaling: float = 1.0
reward_output_bias: float = 0.0
@ -189,15 +190,14 @@ class PPOMATHConfig(CommonExperimentConfig):
default_factory=ModelTrainEvalConfig
)
ref: ModelTrainEvalConfig = dataclasses.field(default_factory=ModelTrainEvalConfig)
rew: ModelTrainEvalConfig = dataclasses.field(default_factory=ModelTrainEvalConfig)
# for manual allocation only
actor_train: MFCConfig = dataclasses.field(default_factory=MFCConfig)
critic_train: MFCConfig = dataclasses.field(default_factory=MFCConfig)
actor_gen: MFCConfig = dataclasses.field(default_factory=MFCConfig)
critic_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
rew_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
ref_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
actor_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
dataset: PromptOnlyDatasetConfig = dataclasses.field(
default_factory=PromptOnlyDatasetConfig
@ -248,30 +248,27 @@ class PPOMATHConfig(CommonExperimentConfig):
@property
def models(self) -> Dict[str, ModelTrainEvalConfig]:
# role to config
models = {
"actor": self.actor,
"critic": self.critic,
"ref": self.ref,
}
if self.ppo.disable_value:
return {
"actor": self.actor,
# "critic": self.critic,
"ref": self.ref,
# "reward": self.rew,
}
else:
return {
"actor": self.actor,
"critic": self.critic,
"ref": self.ref,
"reward": self.rew,
}
models.pop("critic")
return models
@property
def rpcs(self):
if (
self.dataset.max_prompt_len + self.ppo.gen.max_new_tokens
(self._allocation_mode.is_decoupled_vllm() or self.actor.vllm.hybrid_train)
and self.dataset.max_prompt_len + self.ppo.gen.max_new_tokens
> self.actor.vllm.max_seq_len_to_capture
and not self.actor.vllm.enforce_eager
):
raise RuntimeError(
f"vllm max seq len to capture {self.actor.vllm.max_seq_len_to_capture} is "
f"smaller than the prompt length + generation length {self.dataset.max_prompt_len + self.ppo.gen.max_new_tokens}"
f"smaller than the prompt length + generation length "
f"{self.dataset.max_prompt_len + self.ppo.gen.max_new_tokens}"
)
if not os.path.exists(os.getenv("REAL_MATH_METADATA_PATH")):
raise RuntimeError(
@ -334,6 +331,14 @@ class PPOMATHConfig(CommonExperimentConfig):
+ 128,
),
)
rollout_output_keys = [
"seq_no_eos_mask",
"packed_input_ids",
"packed_logprobs",
"prompt_mask",
]
if self.ppo.recompute_logprob:
rollout_output_keys.remove("packed_logprobs")
rollout = MFCDef(
name="actor_gen",
model_name="actor",
@ -343,32 +348,27 @@ class PPOMATHConfig(CommonExperimentConfig):
model_path=self.actor.path,
interface_impl=actor_interface,
input_keys=["packed_prompts"],
output_keys=[
"seq_no_eos_mask",
"packed_input_ids",
"packed_logprobs",
"prompt_mask",
],
output_keys=rollout_output_keys,
n_seqs=self.dataset.train_bs_n_seqs,
)
inf_reward = MFCDef(
name="rew_inf",
model_name="reward",
mb_spec=self.rew_inf.mb_spec,
actor_inf = MFCDef(
name="actor_inf",
model_name="actor",
mb_spec=self.actor_inf.mb_spec,
interface_type=ModelInterfaceType.INFERENCE,
interface_impl=rw_interface,
model_type=self.rew.type,
model_path=self.rew.path,
min_n_seqs_per_pass=1 / self.group_size,
input_keys=["packed_input_ids", "packed_prompts"],
output_keys=["rewards", "dense_rewards"],
model_type=self.actor.type,
model_path=self.actor.path,
interface_impl=actor_interface,
input_keys=["packed_input_ids"],
output_keys=["packed_logprobs"],
output_key_remap=dict(logprobs="packed_logprobs"),
n_seqs=self.dataset.train_bs_n_seqs,
)
# add rew param into ref MFC
inf_ref_inputs = ["packed_input_ids", "packed_prompts"]
inf_ref_outputs = ["packed_ref_logprobs", "rewards", "dense_rewards"]
inf_ref_outputs = ["logprobs", "rewards", "dense_rewards"]
ref_interface = copy.deepcopy(actor_interface)
ref_interface.type_ = "ref_rw"
ref_interface.args["enable_save"] = False
@ -385,6 +385,7 @@ class PPOMATHConfig(CommonExperimentConfig):
min_n_seqs_per_pass=1 / self.group_size,
input_keys=inf_ref_inputs,
output_keys=inf_ref_outputs,
output_key_remap=dict(logprobs="packed_ref_logprobs"),
n_seqs=self.dataset.train_bs_n_seqs,
)
@ -450,47 +451,40 @@ class PPOMATHConfig(CommonExperimentConfig):
min_n_seqs_per_pass=self.ppo.ppo_n_minibatches / self.group_size,
n_seqs=self.dataset.train_bs_n_seqs,
)
rpcs = {
"actor_gen": rollout,
"actor_train": train_actor,
"critic_inf": inf_values,
"critic_train": train_critic,
"ref_inf": inf_ref_logits,
"actor_inf": actor_inf,
"rew_inf": inf_reward,
}
if self.ppo.disable_value:
return {
"actor_gen": rollout,
"actor_train": train_actor,
# "critic_inf": inf_values,
# "critic_train": train_critic,
# "ref_inf": inf_ref_logits,
# "rew_inf": inf_reward,
"ref_rw": inf_ref_logits,
}
else:
return {
"actor_gen": rollout,
"actor_train": train_actor,
"critic_inf": inf_values,
"critic_train": train_critic,
"ref_inf": inf_ref_logits,
"rew_inf": inf_reward,
}
rpcs.pop("critic_inf")
rpcs.pop("critic_train")
if not self.ppo.recompute_logprob:
rpcs.pop("actor_inf")
return rpcs
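
A standalone illustration (not the real config classes) of how the two flags prune the MFC table built above:

def prune_rpcs(disable_value: bool, recompute_logprob: bool) -> set:
    rpcs = {"actor_gen", "actor_train", "critic_inf", "critic_train", "ref_inf", "actor_inf"}
    if disable_value:
        rpcs -= {"critic_inf", "critic_train"}
    if not recompute_logprob:
        rpcs -= {"actor_inf"}
    return rpcs

print(prune_rpcs(disable_value=True, recompute_logprob=False))  # {'actor_gen', 'actor_train', 'ref_inf'}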
@property
def allocations(self):
allocs = {
"actor_gen": self.actor_gen,
"actor_train": self.actor_train,
"critic_inf": self.critic_inf,
"critic_train": self.critic_train,
"ref_inf": self.ref_inf,
"actor_inf": self.actor_inf,
"rew_inf": self.rew_inf,
}
if self.ppo.disable_value:
return {
"actor_gen": self.actor_gen,
"actor_train": self.actor_train,
# "critic_inf": self.critic_inf,
# "critic_train": self.critic_train,
# "ref_inf": self.ref_inf,
# "rew_inf": self.rew_inf,
"ref_rw": self.ref_inf,
}
else:
return {
"actor_gen": self.actor_gen,
"actor_train": self.actor_train,
"critic_inf": self.critic_inf,
"critic_train": self.critic_train,
"ref_inf": self.ref_inf,
"rew_inf": self.rew_inf,
}
allocs.pop("critic_inf")
allocs.pop("critic_train")
if not self.ppo.recompute_logprob:
allocs.pop("actor_inf")
return allocs
@property
def datasets(self):

View File

@ -110,32 +110,40 @@ def make_inf_backend_config(
def resolve_replica_ids(
rpc_allocs: List[RPCAllocation], models: Dict[str, ModelTrainEvalConfig]
):
role_cnt = collections.defaultdict(int)
first_device_mesh = dict()
first_parallel = dict()
first_rpc = dict()
role_rpcs = collections.defaultdict(list)
for alloc in rpc_allocs:
rpc = alloc.rpc
if rpc.role not in first_device_mesh:
first_device_mesh[rpc.role] = alloc.device_mesh
first_parallel[rpc.role] = alloc.parallel
first_rpc[rpc.role] = rpc
role_rpcs[rpc.role].append(alloc)
for role, allocs in role_rpcs.items():
cnt = len(allocs)
if cnt == 1:
allocs[0].rpc.model_name = ModelName(role, 0)
continue
model_cfg = models[rpc.role]
if (rpc.is_train() and first_rpc[rpc.role].is_generate()) or (
rpc.is_generate() and first_rpc[rpc.role].is_train()
):
if model_cfg.vllm.hybrid_train:
role_cnt[rpc.role] += 1
rpc.model_name = ModelName(rpc.role, role_cnt[rpc.role])
rpcs = [alloc.rpc for alloc in allocs]
if any(rpc.is_train() for rpc in rpcs):
main_alloc = next(alloc for alloc in allocs if alloc.rpc.is_train())
elif any(rpc.is_inference() for rpc in rpcs):
main_alloc = next(alloc for alloc in allocs if alloc.rpc.is_inference())
else:
main_alloc = allocs[0]
main_alloc.rpc.model_name = ModelName(role, 0)
i = 1
for alloc in allocs:
if alloc.rpc.name == main_alloc.rpc.name:
continue
if alloc.device_mesh != first_device_mesh[rpc.role] or not parallelism_eq(
alloc.parallel, first_parallel[rpc.role]
):
role_cnt[rpc.role] += 1
rpc.model_name = ModelName(rpc.role, role_cnt[rpc.role])
continue
assert rpc.model_name.replica_id == 0
same_alloc = alloc.device_mesh == main_alloc.device_mesh and parallelism_eq(
alloc.parallel, main_alloc.parallel
)
if not same_alloc or (
alloc.rpc.is_generate()
and main_alloc.rpc.is_train()
and (models[role].vllm.hybrid_train or models[role].sglang.hybrid_train)
):
alloc.rpc.model_name = ModelName(role, i)
i += 1
else:
alloc.rpc.model_name = ModelName(role, 0)
def resolve_rpc_hooks(
@ -207,6 +215,7 @@ class AllocationType(enum.Enum):
HEURISTIC = 4
SEARCH = 5
DECOUPLED_SGLANG = 6
DECOUPLED_MOCK = 7
@dataclasses.dataclass
@ -218,6 +227,7 @@ class AllocationMode:
return self.type_ in [
AllocationType.DECOUPLED_vLLM,
AllocationType.DECOUPLED_SGLANG,
AllocationType.DECOUPLED_MOCK,
]
def is_decoupled_vllm(self):
@ -226,6 +236,9 @@ class AllocationMode:
def is_decoupled_sglang(self):
return self.type_ == AllocationType.DECOUPLED_SGLANG
def is_decoupled_mock(self):
return self.type_ == AllocationType.DECOUPLED_MOCK
def is_global_hybrid(self):
return self.type_ == AllocationType.GLOBAL_HYBRID
@ -246,6 +259,8 @@ class AllocationMode:
return cls(AllocationType.DECOUPLED_vLLM, alloc_decoupled)
elif "sglang" in allocation_mode:
return cls(AllocationType.DECOUPLED_SGLANG, alloc_decoupled)
elif "mock" in allocation_mode:
return cls(AllocationType.DECOUPLED_MOCK, alloc_decoupled)
if alloc_3d:
return cls(AllocationType.GLOBAL_HYBRID, alloc_3d)
if alloc_hybrid:
@ -272,7 +287,7 @@ class AllocationMode:
@staticmethod
def extract_decoupled_alloc(allocation_mode: str) -> Dict | None:
pattern = re.compile(
r"(?:(?:vllm|sglang)\.(.+?)\+(.+))|(?:(.+?)\+(?:vllm|sglang)\.(.+))"
r"(?:(?:vllm|sglang|mock)\.(.+?)\+(.+))|(?:(.+?)\+(?:vllm|sglang|mock)\.(.+))"
)
m = pattern.match(allocation_mode)
if not m:
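
To illustrate the extended pattern, a decoupled SGLang allocation string splits into the generation and training parallel strategies (the degrees are illustrative; d/p/m follow the notation used in the quickstart docstrings):

import re

pattern = re.compile(
    r"(?:(?:vllm|sglang|mock)\.(.+?)\+(.+))|(?:(.+?)\+(?:vllm|sglang|mock)\.(.+))"
)
print(pattern.match("sglang.d4m2p1+d2p2m2").groups())
# ('d4m2p1', 'd2p2m2', None, None)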

View File

@ -14,21 +14,50 @@ from typing import *
import torch
import torch.distributed as dist
import transformers
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from realhf.api.core import model_api
from realhf.api.core.data_api import MicroBatchSpec, SequenceSample
from realhf.api.quickstart.model import (
DistributedDataParallelConfig,
MegatronConfig,
OptimizerConfig,
)
from realhf.base import constants, logging
from realhf.base.datapack import flat2d
from realhf.base.monitor import CUDATimeMarkType, cuda_tmarked
from realhf.impl.model.backend.inference import PipelinableInferenceEngine
from realhf.impl.model.backend.pipe_runner import PipelineRunner, PipeTrainInstrSet
from realhf.impl.model.modules.mlp import get_activation_fn
from realhf.impl.model.nn.flatten_param import ContiguousParamSpec
from realhf.impl.model.nn.real_llm_api import ReaLModel
from realhf.impl.model.nn.real_llm_base import ReaLModelBlock
from realhf.impl.model.parallelism.pipeline_parallel.tensor_storage import TensorBuffer
try:
# Monkey patch
import megatron.core.optimizer as mcore_optim
class DistributedOptimizer(mcore_optim.DistributedOptimizer):
def get_model_parallel_group(self):
return constants.parallelism_group()
mcore_optim.DistributedOptimizer = DistributedOptimizer
from megatron.core import parallel_state
from megatron.core.distributed.distributed_data_parallel import (
DistributedDataParallel,
)
from megatron.core.distributed.param_and_grad_buffer import ParamAndGradBuffer
from megatron.core.optimizer import DistributedOptimizer, get_megatron_optimizer
from megatron.core.optimizer.clip_grads import clip_grad_norm_fp32, count_zeros_fp32
from megatron.core.optimizer.optimizer_config import (
OptimizerConfig as MegatronOptimizerConfig,
)
from megatron.core.transformer.transformer_config import (
TransformerConfig as MegatronTransformerConfig,
)
megatron_available = True
except (ModuleNotFoundError, ImportError):
# importing megatron.core in CPU container will fail due to the requirement of apex
# Here class types must be defined for type hinting
@ -42,19 +71,19 @@ except (ModuleNotFoundError, ImportError):
pass
from realhf.api.core import model_api
from realhf.api.core.data_api import MicroBatchSpec, SequenceSample
from realhf.api.quickstart.model import MegatronConfig, OptimizerConfig
from realhf.base import constants, logging
from realhf.base.datapack import flat2d
from realhf.base.monitor import CUDATimeMarkType, cuda_tmarked
from realhf.impl.model.backend.inference import PipelinableInferenceEngine
from realhf.impl.model.backend.pipe_runner import PipelineRunner, PipeTrainInstrSet
from realhf.impl.model.modules.mlp import get_activation_fn
from realhf.impl.model.nn.flatten_param import ContiguousParamSpec
from realhf.impl.model.nn.real_llm_api import ReaLModel
from realhf.impl.model.nn.real_llm_base import ReaLModelBlock
from realhf.impl.model.parallelism.pipeline_parallel.tensor_storage import TensorBuffer
if megatron_available:
try:
from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler
use_old_megatron = False
except (ModuleNotFoundError, ImportError):
# The above object is available in 0.9.0 but missing in 0.6.0
from realhf.impl.model.backend.thirdparty.megatron.v0_6_0.lr_schduler import (
OptimizerParamScheduler,
)
use_old_megatron = True
WITHIN_MEGATRON_CONTEXT = False
@ -122,6 +151,7 @@ def megatron_ctx():
parallel_state._TENSOR_AND_EXPERT_PARALLEL_GROUP = constants.model_parallel_group()
g = constants.data_parallel_group()
parallel_state._DATA_MODULO_EXPERT_PARALLEL_GROUP = g
parallel_state._DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP = g
parallel_state._DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO = (
grid.get_data_parallel_group_gloo()
)
@ -166,462 +196,62 @@ class MegatronEngine:
self.ddp.zero_grad_buffer()
self.optim.zero_grad(set_to_none=set_to_none)
def _all_reduce_layernorm_grads(self):
if not (
constants.sequence_parallel() and constants.model_parallel_world_size() > 1
):
return
real_model: ReaLModel = self.ddp.module
grads = []
for i in range(real_model.layer_idx_start, real_model.layer_idx_end):
if i == 0:
continue
elif i == real_model.config.n_layers + 1:
continue
else:
assert 0 < i < real_model.config.n_layers + 1
layer: ReaLModelBlock = real_model.layers[
i - real_model.layer_idx_start
]
grads.append(layer.attn.c_attn.ln.weight.main_grad)
if getattr(layer.attn.c_attn.ln, "bias", None) is not None:
grads.append(layer.attn.c_attn.ln.bias.main_grad)
grads.append(layer.mlp.ln.weight.main_grad)
if getattr(layer.mlp.ln, "bias", None) is not None:
grads.append(layer.mlp.ln.bias.main_grad)
if i == real_model.config.n_layers:
grads.append(layer.ln_f.weight.main_grad)
if getattr(layer.ln_f, "bias", None) is not None:
grads.append(layer.ln_f.bias.main_grad)
# Adopted from Megatron-LM/megatron/training/optimizer_param_scheduler.py
class OptimizerParamScheduler(object):
"""Anneals learning rate and weight decay.
assert all(x is not None for x in grads)
coalesced = _flatten_dense_tensors(grads)
dist.all_reduce(coalesced, group=constants.model_parallel_group())
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)
Adopted from Megatron-LM. This class is not included in
megatron.core, so we have to copy-paste it here.
"""
def _all_reduce_word_embedding_grads(self):
real_model: ReaLModel = self.ddp.module
if not real_model.config.tied_embedding or real_model.config.is_critic:
return
pp_size = constants.pipe_parallel_world_size()
pp_rank = constants.pipe_parallel_rank()
if pp_size == 1:
return
if pp_rank not in [0, pp_size - 1]:
return
def __init__(
self,
optimizer,
init_lr,
max_lr,
min_lr,
lr_warmup_steps,
lr_decay_steps,
lr_decay_style,
start_wd,
end_wd,
wd_incr_steps,
wd_incr_style,
use_checkpoint_opt_param_scheduler=True,
override_opt_param_scheduler=False,
):
# Class values.
self.optimizer = optimizer
self.init_lr = init_lr
self.max_lr = float(max_lr)
self.min_lr = min_lr
assert self.min_lr >= 0.0
assert self.max_lr >= self.min_lr
assert self.init_lr <= self.max_lr
self.lr_warmup_steps = lr_warmup_steps
self.num_steps = 0
self.lr_decay_steps = lr_decay_steps
assert self.lr_decay_steps > 0
assert self.lr_warmup_steps < self.lr_decay_steps
self.lr_decay_style = lr_decay_style
self.start_wd = start_wd
self.end_wd = end_wd
assert self.start_wd >= 0.0
assert self.end_wd >= self.start_wd
self.wd_incr_steps = wd_incr_steps
self.wd_incr_style = wd_incr_style
self.override_opt_param_scheduler = override_opt_param_scheduler
self.use_checkpoint_opt_param_scheduler = use_checkpoint_opt_param_scheduler
if self.override_opt_param_scheduler:
assert not self.use_checkpoint_opt_param_scheduler, (
"both override and " "use-checkpoint are set."
)
# Set the learning rate
self.step(0)
self.log_rank_0("> learning rate decay style: {}".format(self.lr_decay_style))
def log_rank_0(self, msg):
if constants.parallelism_rank() == 0:
logger.info(msg)
def get_wd(self):
"""Weight decay incr functions."""
if self.num_steps > self.wd_incr_steps:
return self.end_wd
if self.wd_incr_style == "constant":
assert self.start_wd == self.end_wd
return self.end_wd
incr_ratio = float(self.num_steps) / float(self.wd_incr_steps)
assert incr_ratio >= 0.0
assert incr_ratio <= 1.0
delta_wd = self.end_wd - self.start_wd
if self.wd_incr_style == "linear":
coeff = incr_ratio
elif self.wd_incr_style == "cosine":
coeff = 0.5 * (math.cos(math.pi * (1 - incr_ratio)) + 1.0)
if pp_rank == 0:
grad = real_model.layers[0].wte.weight.main_grad
else:
raise Exception(
"{} weight decay increment style is not supported.".format(
self.wd_incr_style
)
)
grad = real_model.layers[-1].weight.main_grad
return self.start_wd + coeff * delta_wd
dist.all_reduce(grad, group=constants.grid().embedding_proc_group)
def get_lr(self, param_group):
"""Learning rate decay functions from:
https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""
max_lr = param_group.get("max_lr", self.max_lr)
min_lr = param_group.get("min_lr", self.min_lr)
# Use linear warmup for the initial part.
if self.lr_warmup_steps > 0 and self.num_steps <= self.lr_warmup_steps:
return self.init_lr + (
(max_lr - self.init_lr)
* float(self.num_steps)
/ float(self.lr_warmup_steps)
)
# If the learning rate is constant, just return the initial value.
if self.lr_decay_style == "constant":
return max_lr
# For any steps larger than `self.lr_decay_steps`, use `min_lr`.
if self.num_steps > self.lr_decay_steps:
return min_lr
# If we are done with the warmup period, use the decay style.
if self.lr_decay_style == "inverse-square-root":
warmup_steps = max(self.lr_warmup_steps, 1)
num_steps = max(self.num_steps, 1)
lr = max_lr * warmup_steps**0.5 / (num_steps**0.5)
return max(min_lr, lr)
num_steps_ = self.num_steps - self.lr_warmup_steps
decay_steps_ = self.lr_decay_steps - self.lr_warmup_steps
decay_ratio = float(num_steps_) / float(decay_steps_)
assert decay_ratio >= 0.0
assert decay_ratio <= 1.0
delta_lr = max_lr - min_lr
if self.lr_decay_style == "linear":
coeff = 1.0 - decay_ratio
elif self.lr_decay_style == "cosine":
coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
else:
raise Exception(
"{} decay style is not supported.".format(self.lr_decay_style)
)
return min_lr + coeff * delta_lr
def step_absolute(self, num_steps):
"""Set lr for all parameters groups."""
if num_steps is None:
self.num_steps += 1
else:
self.num_steps = num_steps
new_wd = self.get_wd()
for param_group in self.optimizer.param_groups:
new_lr = self.get_lr(param_group)
param_group["lr"] = new_lr * param_group.get("lr_mult", 1.0)
param_group["weight_decay"] = new_wd * param_group.get("wd_mult", 1.0)
def step(self, increment):
"""Set lr for all parameters groups."""
self.num_steps += increment
new_wd = self.get_wd()
for param_group in self.optimizer.param_groups:
new_lr = self.get_lr(param_group)
param_group["lr"] = new_lr * param_group.get("lr_mult", 1.0)
param_group["weight_decay"] = new_wd * param_group.get("wd_mult", 1.0)
def state_dict(self):
state_dict = {
"max_lr": self.max_lr,
"lr_warmup_steps": self.lr_warmup_steps,
"num_steps": self.num_steps,
"lr_decay_style": self.lr_decay_style,
"lr_decay_steps": self.lr_decay_steps,
"min_lr": self.min_lr,
"start_wd": self.start_wd,
"end_wd": self.end_wd,
"wd_incr_style": self.wd_incr_style,
"wd_incr_steps": self.wd_incr_steps,
}
return state_dict
def _check_and_set(self, cls_value, sd_value, name):
"""Auxiliary function for checking the values in the checkpoint and
setting them."""
if self.override_opt_param_scheduler:
self.log_rank_0(" > overriding {} value to {}".format(name, cls_value))
return cls_value
if not self.use_checkpoint_opt_param_scheduler:
assert cls_value == sd_value, (
f"OptimizerParamScheduler: class input value {cls_value} and checkpoint"
f"value {sd_value} for {name} do not match"
)
self.log_rank_0(" > using checkpoint value {} for {}".format(sd_value, name))
return sd_value
def load_state_dict(self, sd):
if "start_lr" in sd:
max_lr_ = sd["start_lr"]
else:
max_lr_ = sd["max_lr"]
self.max_lr = self._check_and_set(self.max_lr, max_lr_, "learning rate")
self.min_lr = self._check_and_set(
self.min_lr, sd["min_lr"], "minimum learning rate"
)
if "warmup_iter" in sd:
lr_warmup_steps_ = sd["warmup_iter"]
elif "warmup_steps" in sd:
lr_warmup_steps_ = sd["warmup_steps"]
else:
lr_warmup_steps_ = sd["lr_warmup_steps"]
self.lr_warmup_steps = self._check_and_set(
self.lr_warmup_steps, lr_warmup_steps_, "warmup iterations"
)
if "end_iter" in sd:
lr_decay_steps_ = sd["end_iter"]
elif "decay_steps" in sd:
lr_decay_steps_ = sd["decay_steps"]
else:
lr_decay_steps_ = sd["lr_decay_steps"]
self.lr_decay_steps = self._check_and_set(
self.lr_decay_steps, lr_decay_steps_, "total number of iterations"
)
if "decay_style" in sd:
lr_decay_style_ = sd["decay_style"]
else:
lr_decay_style_ = sd["lr_decay_style"]
self.lr_decay_style = self._check_and_set(
self.lr_decay_style, lr_decay_style_, "learning rate decay style"
)
if "num_iters" in sd:
num_steps = sd["num_iters"]
else:
num_steps = sd["num_steps"]
self.step(increment=num_steps)
if "start_wd" in sd:
self.start_wd = self._check_and_set(
self.start_wd, sd["start_wd"], "start weight decay"
)
self.end_wd = self._check_and_set(
self.end_wd, sd["end_wd"], "end weight decay"
)
self.wd_incr_steps = self._check_and_set(
self.wd_incr_steps,
sd["wd_incr_steps"],
"total number of weight decay iterations",
)
self.wd_incr_style = self._check_and_set(
self.wd_incr_style,
sd["wd_incr_style"],
"weight decay incr style",
)
@torch.no_grad()
def _step_megatron_distrib_optimizer_internal(optim: DistributedOptimizer):
# NOTE: patching this function to use the correct model parallel group
optim._copy_model_grads_to_main_grads()
# Do unscale, check for inf, and update grad scaler only for
# the case that grad scaler is provided.
if optim.grad_scaler:
def _unscale_main_grads_and_check_for_nan(optim: DistributedOptimizer):
# Collect main grads.
main_grads = optim._collect_main_grad_data_for_unscaling()
# Reset found inf.
optim.found_inf.fill_(0.0)
# Unscale and set found inf/nan
torch._amp_foreach_non_finite_check_and_unscale_(
main_grads, optim.found_inf, optim.grad_scaler.inv_scale
)
# Update across all model parallel instances.
dist.all_reduce(
optim.found_inf,
op=dist.ReduceOp.MAX,
group=constants.parallelism_group(),
)
# Check for nan.
found_inf_flag = optim.found_inf.item() > 0
return found_inf_flag
# Unscale and check for inf/nan.
found_inf_flag = _unscale_main_grads_and_check_for_nan(optim)
# We are done with scaling gradients
# so we can update the loss scale.
optim.grad_scaler.update(found_inf_flag)
# If we found inf/nan, skip the update.
if found_inf_flag:
optim.update_successful, grad_norm, num_zeros_in_grad = (
False,
None,
None,
)
return optim.update_successful, grad_norm, num_zeros_in_grad
def clip_grad_norm(optim: DistributedOptimizer, clip_grad: float) -> float:
"""Compute grad norm."""
params = optim.get_parameters()
grads_for_norm = optim.get_main_grads_for_grad_norm()
return clip_grad_norm_fp32(
params,
grads_for_norm,
clip_grad,
model_parallel_group=constants.parallelism_group(),
)
def count_zeros(optim: DistributedOptimizer) -> float:
"""Count number of zeros in model's gradients."""
params = optim.get_parameters()
return count_zeros_fp32(
params,
model_parallel_group=constants.parallelism_group(),
)
# Clip the main gradients.
grad_norm = None
if optim.config.clip_grad > 0.0:
grad_norm = clip_grad_norm(optim, optim.config.clip_grad)
# Count the zeros in the grads.
num_zeros_in_grad = None
if optim.config.log_num_zeros_in_grad:
num_zeros_in_grad = count_zeros(optim)
# Step the optimizer.
optim.optimizer.step()
# Update params from main params.
optim._copy_main_params_to_model_params()
# Successful update.
optim.update_successful, grad_norm, num_zeros_in_grad = (
True,
grad_norm,
num_zeros_in_grad,
)
return optim.update_successful, grad_norm, num_zeros_in_grad
def step_megatron_distrb_optimizer(optim: DistributedOptimizer):
optim.update_successful, grad_norm, num_zeros_in_grad = (
_step_megatron_distrib_optimizer_internal(optim)
)
# If not overlapping all-gather for parameters, launch synchronous all-gather
# communication calls here. If overlapping all-gather for parameters, the following
# call to _gather_all_model_params is a no-op: the first all-gather is launched
# asynchronously in the next optimizer.zero_grad() call and subsequent all-gathers
# are launched in the forward pre-hook.
optim._reset_metadata_and_sync_gather_all_model_params(force_sync=False)
return optim.update_successful, grad_norm, num_zeros_in_grad
def _flatten_dense_tensors(tensors):
"""Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of
same dense type.
Since inputs are dense, the resulting tensor will be a concatenated 1D
buffer. Element-wise operation on this buffer will be equivalent to
operating individually.
Args:
tensors (Iterable[Tensor]): dense tensors to flatten.
Returns:
A contiguous 1D buffer containing input tensors.
"""
return torch._C._nn.flatten_dense_tensors(tensors)
def _unflatten_dense_tensors(flat, tensors):
"""View a flat buffer using the sizes of tensors. Assume that tensors are
of same dense type, and that flat is given by _flatten_dense_tensors.
Args:
flat (Tensor): flattened dense tensors to unflatten.
tensors (Iterable[Tensor]): dense tensors whose sizes will be used to
unflatten flat.
Returns:
Unflattened dense tensors with sizes same as tensors and values from
flat.
"""
return torch._C._nn.unflatten_dense_tensors(flat, tensors)
def megatron_all_reduce_layernorm_grads(engine: MegatronEngine):
if not (
constants.sequence_parallel() and constants.model_parallel_world_size() > 1
):
return
real_model: ReaLModel = engine.ddp.module
grads = []
for i in range(real_model.layer_idx_start, real_model.layer_idx_end):
if i == 0:
continue
elif i == real_model.config.n_layers + 1:
continue
else:
assert 0 < i < real_model.config.n_layers + 1
layer: ReaLModelBlock = real_model.layers[i - real_model.layer_idx_start]
grads.append(layer.attn.c_attn.ln.weight.main_grad)
if getattr(layer.attn.c_attn.ln, "bias", None) is not None:
grads.append(layer.attn.c_attn.ln.bias.main_grad)
grads.append(layer.mlp.ln.weight.main_grad)
if getattr(layer.mlp.ln, "bias", None) is not None:
grads.append(layer.mlp.ln.bias.main_grad)
if i == real_model.config.n_layers:
grads.append(layer.ln_f.weight.main_grad)
if getattr(layer.ln_f, "bias", None) is not None:
grads.append(layer.ln_f.bias.main_grad)
assert all(x is not None for x in grads)
coalesced = _flatten_dense_tensors(grads)
dist.all_reduce(coalesced, group=constants.model_parallel_group())
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)
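# A self-contained sketch of the flatten -> reduce -> unflatten -> copy_ pattern
# used by megatron_all_reduce_layernorm_grads above, run on CPU tensors with an
# element-wise multiply standing in for dist.all_reduce so no process group is
# required.
import torch

_bufs = [torch.ones(3), torch.ones(2, 2)]
_flat = _flatten_dense_tensors(_bufs)  # one contiguous 1-D buffer
_flat.mul_(2.0)  # stand-in for dist.all_reduce(_flat, group=...)
for _buf, _synced in zip(_bufs, _unflatten_dense_tensors(_flat, _bufs)):
    _buf.copy_(_synced)  # write the reduced values back into the original buffers
assert all(bool((b == 2.0).all()) for b in _bufs)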
def megatron_all_reduce_word_embedding_grads(engine: MegatronEngine):
real_model: ReaLModel = engine.ddp.module
if not real_model.config.tied_embedding or real_model.config.is_critic:
return
pp_size = constants.pipe_parallel_world_size()
pp_rank = constants.pipe_parallel_rank()
if pp_size == 1:
return
if pp_rank not in [0, pp_size - 1]:
return
if pp_rank == 0:
grad = real_model.layers[0].wte.weight.main_grad
else:
grad = real_model.layers[-1].weight.main_grad
dist.all_reduce(grad, group=constants.grid().embedding_proc_group)
def finalize_grads_megatron(engine: MegatronEngine):
engine.ddp.finish_grad_sync()
megatron_all_reduce_layernorm_grads(engine)
megatron_all_reduce_word_embedding_grads(engine)
def finalize_grads(self):
self.ddp.finish_grad_sync()
self._all_reduce_layernorm_grads()
self._all_reduce_word_embedding_grads()
@dataclasses.dataclass
@ -681,7 +311,7 @@ class PipeTrainInstrSetForMegatron(PipeTrainInstrSet):
step_id: int,
):
# self.engine.ddp.start_grad_sync()
finalize_grads_megatron(self.engine)
self.engine.finalize_grads()
@cuda_tmarked("opt", CUDATimeMarkType.optim_step)
def _exec_optimizer_step(
@ -692,16 +322,12 @@ class PipeTrainInstrSetForMegatron(PipeTrainInstrSet):
micro_batch_id: int,
step_id: int,
):
if isinstance(self.engine.optim, DistributedOptimizer):
update_successful, grad_norm, num_zeros_in_grad = (
step_megatron_distrb_optimizer(self.engine.optim)
)
else:
update_successful, grad_norm, num_zeros_in_grad = self.engine.optim.step()
update_successful, grad_norm, num_zeros_in_grad = self.engine.optim.step()
version_steps = tensor_buffer.get("version_steps", 0)
if update_successful:
self.engine.lr_scheduler.step_absolute(version_steps)
incr = version_steps - self.engine.lr_scheduler.num_steps
self.engine.lr_scheduler.step(incr)
if constants.data_parallel_rank() == 0 and constants.model_parallel_rank() == 0:
logger.info(
f"Model name {constants.model_name()}, "
@ -824,7 +450,7 @@ class ReaLMegatronEngine(model_api.PipelinableEngine):
for k, v in _stat.items():
stat[k] += v
finalize_grads_megatron(self.engine)
self.engine.finalize_grads()
self._step(version_steps)
return stat
@ -834,12 +460,14 @@ class ReaLMegatronEngine(model_api.PipelinableEngine):
self,
input_: SequenceSample,
mb_spec: MicroBatchSpec,
output_seqlens: List[List[int]] | None = None,
post_hook: Callable[[torch.Tensor, SequenceSample], Any] | None = None,
aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
):
return self.inf_engine.forward(
input_=input_,
mb_spec=mb_spec,
output_seqlens=output_seqlens,
post_hook=post_hook,
aggregate_fn=aggregate_fn,
)
@ -864,14 +492,10 @@ class ReaLMegatronEngine(model_api.PipelinableEngine):
# wrapper for profiler
@cuda_tmarked("opt", CUDATimeMarkType.optim_step)
def _step(self, version_steps):
if isinstance(self.engine.optim, DistributedOptimizer):
update_successful, grad_norm, _ = step_megatron_distrb_optimizer(
self.engine.optim
)
else:
update_successful, grad_norm, _ = self.engine.optim.step()
update_successful, grad_norm, _ = self.engine.optim.step()
if update_successful:
self.engine.lr_scheduler.step_absolute(version_steps)
incr = version_steps - self.engine.lr_scheduler.num_steps
self.engine.lr_scheduler.step(incr)
if constants.data_parallel_rank() == 0 and constants.model_parallel_rank() == 0:
logger.info(
f"Megatron backend update success? {update_successful}. "
@ -884,10 +508,6 @@ class ReaLMegatronEngine(model_api.PipelinableEngine):
@dataclasses.dataclass
class MegatronTrainBackend(model_api.ModelBackend, MegatronConfig):
bf16: bool = False
zero_stage: int = dataclasses.field(
metadata={"choices": [0, 1, 2, 3]},
default=2,
)
optimizer: OptimizerConfig = dataclasses.field(default_factory=OptimizerConfig)
def _initialize(
@ -896,21 +516,32 @@ class MegatronTrainBackend(model_api.ModelBackend, MegatronConfig):
module = model.module
if not isinstance(module, ReaLModel):
raise ValueError("MegatronTrainBackend only supports ReaLModel.")
if isinstance(self.ddp, dict):
self.ddp = DistributedDataParallelConfig(**self.ddp)
with megatron_ctx():
module = DistributedDataParallel(
config=get_megatron_transformer_config(module.config),
module=module,
data_parallel_group=constants.data_parallel_group(),
accumulate_allreduce_grads_in_fp32=self.accumulate_allreduce_grads_in_fp32,
overlap_grad_reduce=self.overlap_grad_reduce,
use_distributed_optimizer=self.zero_stage > 0,
expert_data_parallel_group=None,
disable_bucketing=False,
check_for_nan_in_grad=False,
)
if use_old_megatron:
module = DistributedDataParallel(
config=get_megatron_transformer_config(module.config),
module=module,
data_parallel_group=constants.data_parallel_group(),
accumulate_allreduce_grads_in_fp32=self.ddp.grad_reduce_in_fp32,
overlap_grad_reduce=self.ddp.overlap_grad_reduce,
use_distributed_optimizer=self.ddp.use_distributed_optimizer,
expert_data_parallel_group=None,
disable_bucketing=False,
check_for_nan_in_grad=self.ddp.check_for_nan_in_grad,
bucket_size=self.ddp.bucket_size,
)
else:
module = DistributedDataParallel(
config=get_megatron_transformer_config(module.config),
ddp_config=self.ddp,
module=module,
disable_bucketing=False,
)
real_model: ReaLModel = module.module
if self.zero_stage > 0:
if self.ddp.use_distributed_optimizer:
# Remap parameters.
assert len(module.buffers) == 1
param_grad_buf: ParamAndGradBuffer = module.buffers[0]
@ -946,22 +577,23 @@ class MegatronTrainBackend(model_api.ModelBackend, MegatronConfig):
opt_cfg = MegatronOptimizerConfig(
optimizer=self.optimizer.type,
bf16=self.bf16,
fp16=not self.bf16,
lr=lr,
min_lr=self.optimizer.min_lr_ratio * lr,
weight_decay=wd,
params_dtype=real_model.dtype,
initial_loss_scale=self.optimizer.initial_loss_scale,
adam_beta1=betas[0],
adam_beta2=betas[1],
adam_eps=self.optimizer.eps,
sgd_momentum=0.9,
use_distributed_optimizer=self.zero_stage > 0,
overlap_grad_reduce=self.overlap_grad_reduce,
overlap_param_gather=self.overlap_param_gather,
clip_grad=self.optimizer.gradient_clipping,
min_loss_scale=self.optimizer.min_loss_scale,
loss_scale_window=self.optimizer.loss_scale_window,
hysteresis=self.optimizer.hysteresis,
adam_beta1=betas[0],
adam_beta2=betas[1],
adam_eps=self.optimizer.eps,
use_distributed_optimizer=self.ddp.use_distributed_optimizer,
overlap_grad_reduce=self.ddp.overlap_grad_reduce,
overlap_param_gather=self.ddp.overlap_param_gather,
clip_grad=self.optimizer.gradient_clipping,
log_num_zeros_in_grad=False,
)
with megatron_ctx():
@ -1005,7 +637,7 @@ class MegatronTrainBackend(model_api.ModelBackend, MegatronConfig):
# create circular references (grad -> param -> grad).
# Deleting models directly will not release the memory.
# We must disable hooks at first.
if self.zero_stage > 0 and self.overlap_param_gather:
if self.ddp.use_distributed_optimizer and self.ddp.overlap_param_gather:
optimizer.disable_pre_hook()
def save(self, model: model_api.Model, save_dir: str):

View File

@ -179,12 +179,14 @@ class MockTrainEngine(model_api.PipelinableEngine):
self,
input_: SequenceSample,
mb_spec: MicroBatchSpec,
output_seqlens: List[List[int]] | None = None,
post_hook: Callable[[torch.Tensor, SequenceSample], Any] | None = None,
aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
):
return self.inf_engine.forward(
input_=input_,
mb_spec=mb_spec,
output_seqlens=output_seqlens,
post_hook=post_hook,
aggregate_fn=aggregate_fn,
)

View File

@ -0,0 +1,453 @@
# Copyright 2025 Ant Group Inc.
import asyncio
import dataclasses
import json
import os
import sys
import time
import traceback
from importlib.metadata import version
from typing import Dict, List, Tuple
import aiohttp
import requests
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import transformers
from packaging.version import Version
from tqdm.asyncio import tqdm
from realhf.api.core import data_api
from realhf.api.core.model_api import (
APIGenerateInput,
APIGenerateOutput,
FinetuneSpec,
GenerationHyperparameters,
LLMAPIClient,
Model,
ModelBackend,
PipelinableEngine,
register_backend,
)
from realhf.api.quickstart.model import SGLangConfig
from realhf.base import cluster, constants, gpu_utils, logging, network, seeding
logger = logging.getLogger("SGLang backend")
def remove_prefix(text: str, prefix: str) -> str:
return text[len(prefix) :] if text.startswith(prefix) else text
class SGLangAPIClient(LLMAPIClient):
async def _do_generate(
self, req: APIGenerateInput, stream: bool = False
) -> APIGenerateOutput:
gconfig = req.gconfig
sample_params = {
"n": gconfig.n,
"top_p": gconfig.top_p,
"top_k": gconfig.top_k,
"max_new_tokens": gconfig.max_new_tokens,
"temperature": 0.0 if gconfig.greedy else gconfig.temperature,
"stop_token_ids": req.stop_token_ids,
}
payload = {
"input_ids": req.prompt_ids,
"sampling_params": sample_params,
"return_logprob": req.return_logprob,
"stream": stream,
}
output = APIGenerateOutput.from_input(req)
# The following code is partially adapted from sglang/bench_serving.py
output_ids = []
output_logprobs = []
finish_reason = {}
ttft = 0.0
latency = float("inf")
st = time.perf_counter()
most_recent_timestamp = st
timeout = aiohttp.ClientTimeout(total=None, connect=30, sock_read=None)
try:
async with self.session.post(
url=self.generate_url,
json=payload,
timeout=timeout,
) as response:
if response.status == 200:
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
latency = time.perf_counter() - st
if chunk == "[DONE]":
pass
else:
data = json.loads(chunk)
# NOTE: Some completion API might have a last
# usage summary response without a token so we
# want to check a token was generated
if data["token_ids"]:
timestamp = time.perf_counter()
# First token
if ttft == 0.0:
ttft = time.perf_counter() - st
output.ttft = ttft
# Decoding phase
else:
output.itl.append(timestamp - most_recent_timestamp)
most_recent_timestamp = timestamp
output_ids = data["token_ids"]
finish_reason = data["meta_info"]["finish_reason"]
output_logprobs = data["meta_info"][
"output_token_logprobs"
]
assert finish_reason["type"] in ["length", "stop"], finish_reason
output.output_logprobs = [x[0] for x in output_logprobs]
output.output_ids = output_ids
output.no_eos = finish_reason["type"] == "length"
output.success = True
output.latency = latency
else:
output.error = response.reason or ""
output.success = False
except Exception as e:
output.success = False
exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info))
raise RuntimeError(
f"SGLang generation request fails:\n{output.error}"
) from e
return output
async def async_update_weights_from_disk(self, path):
timeout = aiohttp.ClientTimeout(total=300, connect=30, sock_read=None)
async with self.session.post(
url=self.update_weights_url,
json=dict(model_path=path),
timeout=timeout,
) as response:
if response.status != 200:
raise RuntimeError("Update weights failed.")
def sglang_server_process(server_args_dict):
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import kill_process_tree
sglang_version = version("sglang")
if Version(sglang_version) < Version("0.4.3"):
from sglang.srt.server import launch_server
server_args_dict.pop("enable_nccl_nvls")
server_args_dict.pop("triton_attention_num_kv_splits")
server_args_dict.pop("cuda_graph_bs")
server_args_dict.pop("enable_memory_saver")
server_args_dict.pop("allow_auto_truncate")
server_args_dict.pop("return_hidden_states")
else:
from sglang.srt.entrypoints.http_server import launch_server
server_args = ServerArgs(**server_args_dict)
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
map(str, list(range(gpu_utils.gpu_count())))
)
try:
launch_server(server_args)
finally:
kill_process_tree(os.getpid(), include_parent=False)
class SGLangGenerationEngine(PipelinableEngine):
def __init__(
self,
server_args_dict: Dict,
hybrid_train: bool,
request_timeout: int = 1800,
):
if constants.model_parallel_rank() != 0:
dist.barrier(group=constants.model_parallel_cpu_group())
return
# Start the serving process
self.server_proc = mp.Process(
target=sglang_server_process,
args=(server_args_dict,),
)
self.server_proc.start()
self.base_url = f"http://{server_args_dict['host']}:{server_args_dict['port']}"
self.api_urls = {
"generate": f"{self.base_url}/generate",
"offload_weights": f"{self.base_url}/offload_weights",
"init_kv_cache": f"{self.base_url}/init_kv_cache",
"clear_kv_cache": f"{self.base_url}/clear_kv_cache",
"init_model_weights": f"{self.base_url}/init_model_weights",
"update_weights_from_disk": f"{self.base_url}/update_weights_from_disk",
}
asyncio.run(self.wait_server())
self.request_timeout = request_timeout
# offload weights/cache
self.hybrid_train = hybrid_train
dist.barrier(group=constants.model_parallel_cpu_group())
def __del__(self):
if hasattr(self, "server_proc"):
from sglang.srt.utils import kill_process_tree
self.server_proc.terminate()
kill_process_tree(os.getpid())
# NOTE: A placeholder function.
def train(self, mode: bool = True):
return self
# NOTE: A placeholder function.
def eval(self):
return self
async def wait_server(self):
# Wait until the server is launched
from sglang.srt.utils import kill_process_tree
from sglang.utils import get_exception_traceback
success = False
for _ in range(120):
await asyncio.sleep(1)
try:
res = requests.get(
self.base_url + "/get_model_info", timeout=5, headers={}
)
assert res.status_code == 200, f"{res=}, {res.text=}"
success = True
break
except (AssertionError, requests.exceptions.RequestException):
last_traceback = get_exception_traceback()
pass
if not success:
logger.error(f"Initialization failed. warmup error: {last_traceback}")
kill_process_tree(os.getpid())
return
async def async_generate(
self,
input_: data_api.SequenceSample,
mb_spec: data_api.MicroBatchSpec,
tokenizer: transformers.PreTrainedTokenizerFast,
gconfig: GenerationHyperparameters = dataclasses.field(
default_factory=GenerationHyperparameters
),
stream: bool = False,
disable_tqdm: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor | None] | None:
pbar = None if disable_tqdm else tqdm(total=input_.bs * gconfig.n)
async with SGLangAPIClient(
generate_url=self.api_urls["generate"],
update_weights_url=self.api_urls["update_weights_from_disk"],
) as client:
tasks = []
input_queries = []
for d in input_.unpack():
if len(d.seqlens["packed_input_ids"]) > 1:
raise RuntimeError(
f"sglang backend does not support grouped generation "
f"for now. Group size {len(d.seqlens['packed_input_ids'])}."
)
prompt_token_ids = d.data["packed_input_ids"].cpu().numpy().tolist()
qid = d.ids[0]
for group_idx in range(gconfig.n):
req = APIGenerateInput(
qid=qid,
group_idx=group_idx,
prompt_ids=prompt_token_ids,
input_ids=prompt_token_ids,
gconfig=gconfig.new(n=1),
stop_token_ids=[tokenizer.pad_token_id, tokenizer.eos_token_id],
return_logprob=True,
)
input_queries.append((qid, group_idx))
tasks.append(
client.async_add_generate_request(
req,
stream=stream,
)
)
outputs = {}
for r in asyncio.as_completed(tasks):
out = await r
outputs[(out.qid, out.group_idx)] = out
if pbar:
pbar.update(1)
if pbar is not None:
pbar.close()
results: List[APIGenerateOutput] = [outputs[key] for key in input_queries]
# Build the output: generated token ids, generated token scores,
# and logits mask (which will always be None in sglang).
batch_token_ids = []
batch_logprobs = []
max_seqlen = -1
for x in results:
max_seqlen = max(max_seqlen, len(x.output_ids))
batch_token_ids.append(x.output_ids)
batch_logprobs.append(x.output_logprobs)
# To be consistent with our internal implementation,
# we should pad generated tokens and logprobs
batch_token_ids = [
t + [tokenizer.pad_token_id] * (max_seqlen - len(t))
for t in batch_token_ids
]
batch_logprobs = [p + [0.0] * (max_seqlen - len(p)) for p in batch_logprobs]
return (
torch.tensor(
batch_token_ids, dtype=torch.long, device=constants.current_device()
),
torch.tensor(
batch_logprobs, dtype=torch.float32, device=constants.current_device()
),
None,
)
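# A small illustration of the right-padding performed above when assembling the
# batched outputs; the pad id 0 and the token lists are placeholder values.
import torch

_token_lists = [[5, 6], [7, 8, 9]]
_pad_token_id = 0
_maxlen = max(len(t) for t in _token_lists)
_batch = torch.tensor(
    [t + [_pad_token_id] * (_maxlen - len(t)) for t in _token_lists],
    dtype=torch.long,
)
# _batch.shape == (2, 3); shorter generations are right-padded with the pad id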
def generate(
self,
input_: data_api.SequenceSample,
mb_spec: data_api.MicroBatchSpec,
tokenizer: transformers.PreTrainedTokenizerFast,
gconfig: GenerationHyperparameters = dataclasses.field(
default_factory=GenerationHyperparameters
),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor | None] | None:
if gconfig.min_new_tokens != 0:
raise RuntimeError(
"NOTE: passing in an arbitrary `min_new_tokens` will lead to a bug for SGLang v0.4.3 "
"because we force to skip_tokenizer_init."
)
if constants.model_parallel_rank() != 0:
dist.barrier(group=constants.model_parallel_cpu_group())
return None, None, None
results = asyncio.run(
self.async_generate(
input_=input_,
mb_spec=mb_spec,
tokenizer=tokenizer,
gconfig=gconfig,
)
)
dist.barrier(group=constants.model_parallel_cpu_group())
return results
def update_weights_from_disk(self, path):
if constants.model_parallel_rank() != 0:
dist.barrier(group=constants.model_parallel_cpu_group())
return
async def _fn():
async with SGLangAPIClient(
generate_url=self.api_urls["generate"],
update_weights_url=self.api_urls["update_weights_from_disk"],
) as client:
await client.async_update_weights_from_disk(path)
asyncio.run(_fn())
dist.barrier(group=constants.model_parallel_cpu_group())
@dataclasses.dataclass
class SGLangGenerationBackend(ModelBackend, SGLangConfig):
model_path: str = ""
dtype: str = "float16"
def _initialize(self, model: Model, spec: FinetuneSpec) -> Model:
if constants.pipe_parallel_world_size() != 1:
raise RuntimeError("SGLang does not support pipe parallel size > 1.")
if constants.model_parallel_world_size() > cluster.spec.n_gpus_per_node:
raise RuntimeError(
"AReaL's SGLang integration does not support model parallel size > n_gpus_per_node."
)
additional_args = dataclasses.asdict(self)
additional_args.pop("hybrid_train")
additional_args["random_seed"] = seeding.get_seed()
# For simplicity, we let all DP ranks have different ports.
ports = [None for _ in range(constants.data_parallel_world_size())]
while any(port is None for port in ports) or len(set(ports)) != len(ports):
dist.all_gather_object(
ports, network.find_free_port(), group=constants.data_parallel_group()
)
additional_args["port"] = ports[constants.data_parallel_rank()]
server_args_dict = dict(
host="localhost",
# Model and tokenizer
tokenizer_path=self.model_path,
tokenizer_mode="auto",
load_format="auto",
trust_remote_code=True,
kv_cache_dtype="auto",
device="cuda",
served_model_name=f"{constants.experiment_name()}/{constants.trial_name()}/{constants.model_name().role}",
is_embedding=False,
skip_tokenizer_init=True,
# Other runtime options
tp_size=constants.model_parallel_world_size(),
# Because we have set CUDA_VISIBLE_DEVICES to a single GPU in each process
base_gpu_id=int(os.environ["CUDA_VISIBLE_DEVICES"]),
file_storage_pth=os.path.join(
constants.SGLANG_CACHE_PATH,
f"sglang_storage{constants.data_parallel_rank()}",
),
# Data parallelism
dp_size=1, # TODO: check whether we require SGLang dp
load_balance_method="round_robin",
# Expert parallelism
ep_size=1, # TODO: check
# Multi-node distributed serving
dist_init_addr=None,
nnodes=1,
node_rank=0,
**additional_args,
)
model.module = SGLangGenerationEngine(
server_args_dict,
hybrid_train=self.hybrid_train,
)
model.backend_name = "sglang"
return model
def load(self, model: Model, load_dir: str):
model.module.update_weights_from_disk(load_dir)
register_backend("sglang", SGLangGenerationBackend)

View File

@ -0,0 +1,249 @@
import math
from realhf.base import constants, logging
logger = logging.getLogger(__name__)
# Adapted from Megatron-LM/megatron/training/optimizer_param_scheduler.py
class OptimizerParamScheduler(object):
"""Anneals learning rate and weight decay.
Adapted from Megatron-LM. This class is not included in
megatron.core, so we have to copy-paste it here.
"""
def __init__(
self,
optimizer,
init_lr,
max_lr,
min_lr,
lr_warmup_steps,
lr_decay_steps,
lr_decay_style,
start_wd,
end_wd,
wd_incr_steps,
wd_incr_style,
use_checkpoint_opt_param_scheduler=True,
override_opt_param_scheduler=False,
):
# Class values.
self.optimizer = optimizer
self.init_lr = init_lr
self.max_lr = float(max_lr)
self.min_lr = min_lr
assert self.min_lr >= 0.0
assert self.max_lr >= self.min_lr
assert self.init_lr <= self.max_lr
self.lr_warmup_steps = lr_warmup_steps
self.num_steps = 0
self.lr_decay_steps = lr_decay_steps
assert self.lr_decay_steps > 0
assert self.lr_warmup_steps < self.lr_decay_steps
self.lr_decay_style = lr_decay_style
self.start_wd = start_wd
self.end_wd = end_wd
assert self.start_wd >= 0.0
assert self.end_wd >= self.start_wd
self.wd_incr_steps = wd_incr_steps
self.wd_incr_style = wd_incr_style
self.override_opt_param_scheduler = override_opt_param_scheduler
self.use_checkpoint_opt_param_scheduler = use_checkpoint_opt_param_scheduler
if self.override_opt_param_scheduler:
assert not self.use_checkpoint_opt_param_scheduler, (
"both override and " "use-checkpoint are set."
)
# Set the learning rate
self.step(0)
self.log_rank_0("> learning rate decay style: {}".format(self.lr_decay_style))
def log_rank_0(self, msg):
if constants.parallelism_rank() == 0:
logger.info(msg)
def get_wd(self):
"""Weight decay incr functions."""
if self.num_steps > self.wd_incr_steps:
return self.end_wd
if self.wd_incr_style == "constant":
assert self.start_wd == self.end_wd
return self.end_wd
incr_ratio = float(self.num_steps) / float(self.wd_incr_steps)
assert incr_ratio >= 0.0
assert incr_ratio <= 1.0
delta_wd = self.end_wd - self.start_wd
if self.wd_incr_style == "linear":
coeff = incr_ratio
elif self.wd_incr_style == "cosine":
coeff = 0.5 * (math.cos(math.pi * (1 - incr_ratio)) + 1.0)
else:
raise Exception(
"{} weight decay increment style is not supported.".format(
self.wd_incr_style
)
)
return self.start_wd + coeff * delta_wd
def get_lr(self, param_group):
"""Learning rate decay functions from:
https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""
max_lr = param_group.get("max_lr", self.max_lr)
min_lr = param_group.get("min_lr", self.min_lr)
# Use linear warmup for the initial part.
if self.lr_warmup_steps > 0 and self.num_steps <= self.lr_warmup_steps:
return self.init_lr + (
(max_lr - self.init_lr)
* float(self.num_steps)
/ float(self.lr_warmup_steps)
)
# If the learning rate is constant, just return the initial value.
if self.lr_decay_style == "constant":
return max_lr
# For any steps larger than `self.lr_decay_steps`, use `min_lr`.
if self.num_steps > self.lr_decay_steps:
return min_lr
# If we are done with the warmup period, use the decay style.
if self.lr_decay_style == "inverse-square-root":
warmup_steps = max(self.lr_warmup_steps, 1)
num_steps = max(self.num_steps, 1)
lr = max_lr * warmup_steps**0.5 / (num_steps**0.5)
return max(min_lr, lr)
num_steps_ = self.num_steps - self.lr_warmup_steps
decay_steps_ = self.lr_decay_steps - self.lr_warmup_steps
decay_ratio = float(num_steps_) / float(decay_steps_)
assert decay_ratio >= 0.0
assert decay_ratio <= 1.0
delta_lr = max_lr - min_lr
if self.lr_decay_style == "linear":
coeff = 1.0 - decay_ratio
elif self.lr_decay_style == "cosine":
coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
else:
raise Exception(
"{} decay style is not supported.".format(self.lr_decay_style)
)
return min_lr + coeff * delta_lr
def step(self, increment):
"""Set lr for all parameters groups."""
self.num_steps += increment
new_wd = self.get_wd()
for param_group in self.optimizer.param_groups:
new_lr = self.get_lr(param_group)
param_group["lr"] = new_lr * param_group.get("lr_mult", 1.0)
param_group["weight_decay"] = new_wd * param_group.get("wd_mult", 1.0)
def state_dict(self):
state_dict = {
"max_lr": self.max_lr,
"lr_warmup_steps": self.lr_warmup_steps,
"num_steps": self.num_steps,
"lr_decay_style": self.lr_decay_style,
"lr_decay_steps": self.lr_decay_steps,
"min_lr": self.min_lr,
"start_wd": self.start_wd,
"end_wd": self.end_wd,
"wd_incr_style": self.wd_incr_style,
"wd_incr_steps": self.wd_incr_steps,
}
return state_dict
def _check_and_set(self, cls_value, sd_value, name):
"""Auxiliary function for checking the values in the checkpoint and
setting them."""
if self.override_opt_param_scheduler:
self.log_rank_0(" > overriding {} value to {}".format(name, cls_value))
return cls_value
if not self.use_checkpoint_opt_param_scheduler:
assert cls_value == sd_value, (
f"OptimizerParamScheduler: class input value {cls_value} and checkpoint"
f"value {sd_value} for {name} do not match"
)
self.log_rank_0(" > using checkpoint value {} for {}".format(sd_value, name))
return sd_value
def load_state_dict(self, sd):
if "start_lr" in sd:
max_lr_ = sd["start_lr"]
else:
max_lr_ = sd["max_lr"]
self.max_lr = self._check_and_set(self.max_lr, max_lr_, "learning rate")
self.min_lr = self._check_and_set(
self.min_lr, sd["min_lr"], "minimum learning rate"
)
if "warmup_iter" in sd:
lr_warmup_steps_ = sd["warmup_iter"]
elif "warmup_steps" in sd:
lr_warmup_steps_ = sd["warmup_steps"]
else:
lr_warmup_steps_ = sd["lr_warmup_steps"]
self.lr_warmup_steps = self._check_and_set(
self.lr_warmup_steps, lr_warmup_steps_, "warmup iterations"
)
if "end_iter" in sd:
lr_decay_steps_ = sd["end_iter"]
elif "decay_steps" in sd:
lr_decay_steps_ = sd["decay_steps"]
else:
lr_decay_steps_ = sd["lr_decay_steps"]
self.lr_decay_steps = self._check_and_set(
self.lr_decay_steps, lr_decay_steps_, "total number of iterations"
)
if "decay_style" in sd:
lr_decay_style_ = sd["decay_style"]
else:
lr_decay_style_ = sd["lr_decay_style"]
self.lr_decay_style = self._check_and_set(
self.lr_decay_style, lr_decay_style_, "learning rate decay style"
)
if "num_iters" in sd:
num_steps = sd["num_iters"]
else:
num_steps = sd["num_steps"]
self.step(increment=num_steps)
if "start_wd" in sd:
self.start_wd = self._check_and_set(
self.start_wd, sd["start_wd"], "start weight decay"
)
self.end_wd = self._check_and_set(
self.end_wd, sd["end_wd"], "end weight decay"
)
self.wd_incr_steps = self._check_and_set(
self.wd_incr_steps,
sd["wd_incr_steps"],
"total number of weight decay iterations",
)
self.wd_incr_style = self._check_and_set(
self.wd_incr_style,
sd["wd_incr_style"],
"weight decay incr style",
)
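# A compact, standalone restatement of the warmup-then-cosine rule implemented by
# get_lr() above, useful for sanity-checking a schedule by hand; all argument
# values below are made up for illustration.
import math

def _warmup_cosine_lr(step, init_lr, max_lr, min_lr, warmup_steps, decay_steps):
    if warmup_steps > 0 and step <= warmup_steps:
        return init_lr + (max_lr - init_lr) * step / warmup_steps
    if step > decay_steps:
        return min_lr
    ratio = (step - warmup_steps) / (decay_steps - warmup_steps)
    coeff = 0.5 * (math.cos(math.pi * ratio) + 1.0)
    return min_lr + coeff * (max_lr - min_lr)

# e.g. _warmup_cosine_lr(50, 0.0, 1e-4, 1e-5, 100, 1000) == 5e-05 (mid-warmup)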

View File

@ -136,6 +136,8 @@ def setup_global_comm(
for model_name, ranks in mw_ranks.items():
model_groups[model_name] = topology.new_or_get_group(ranks, backend=backend)
constants.set_parallelism_group(model_name, model_groups[model_name], ranks)
cpu_group = topology.new_or_get_group(ranks, backend="gloo")
constants.set_cpu_parallelism_group(model_name, cpu_group)
self_group = None
for i in range(world_size):

View File

@ -62,6 +62,9 @@ class ParamReallocInfo:
param_realloc_model_group: Dict[
ParamReallocModelPair, torch.distributed.ProcessGroup
]
param_realloc_model_cpu_group: Dict[
ParamReallocModelPair, torch.distributed.ProcessGroup
]
param_realloc_groups: Dict[ParamReallocPair, torch.distributed.ProcessGroup]
param_realloc_src_ranks: Dict[ParamReallocPair, int]
param_realloc_dst_ranks: Dict[ParamReallocPair, List[int]]
@ -270,6 +273,7 @@ def setup_param_realloc(
param_realloc_src_ranks = {}
param_realloc_dst_ranks = {}
param_realloc_model_group = {}
param_realloc_model_cpu_group = {}
if param_realloc_pairs is not None:
for src, dst in param_realloc_pairs:
_create_param_realloc_groups(
@ -296,11 +300,15 @@ def setup_param_realloc(
param_realloc_model_group[ParamReallocModelPair(src, dst)] = (
topology.new_or_get_group(list(sorted(pair_mw_ranks)))
)
param_realloc_model_cpu_group[ParamReallocModelPair(src, dst)] = (
topology.new_or_get_group(list(sorted(pair_mw_ranks)), backend="gloo")
)
return ParamReallocInfo(
param_realloc_groups=param_realloc_groups,
param_realloc_src_ranks=param_realloc_src_ranks,
param_realloc_dst_ranks=param_realloc_dst_ranks,
param_realloc_model_group=param_realloc_model_group,
param_realloc_model_cpu_group=param_realloc_model_cpu_group,
)

View File

@ -444,13 +444,13 @@ class PPOActorInterface(model_api.ModelInterface):
)
res = SequenceSample(
keys=["packed_ref_logprobs"],
keys=["logprobs"],
ids=input_.ids,
dtypes=dict(packed_ref_logprobs=model.module.dtype),
trailing_shapes=dict(packed_ref_logprobs=()),
data=dict(packed_ref_logprobs=logprobs),
dtypes=dict(logprobs=model.module.dtype),
trailing_shapes=dict(logprobs=()),
data=dict(logprobs=logprobs),
seqlens=dict(
packed_ref_logprobs=[
logprobs=[
[x - 1 for x in slen] for slen in input_.seqlens["packed_input_ids"]
]
),

View File

@ -11,21 +11,17 @@ from typing import Dict, Literal, Optional, Tuple
import torch
import realhf.api.core.model_api as model_api
import realhf.base.logging as logging
import realhf.impl.model.utils.ppo_functional as ppo_functional
from realhf.api.core.data_api import MicroBatchSpec, SequenceSample
from realhf.base.datapack import flat2d
from realhf.impl.model.interface.rw_interface import PackedRewardInterface
from realhf.impl.model.nn.real_llm_api import ReaLModel
from realhf.impl.model.utils.functional import (
gather_packed_shifted_log_probs,
masked_normalization,
)
from realhf.impl.model.interface.rw_interface import PackedRewardInterface
logger = logging.getLogger("RefRwInterface")

View File

@ -122,6 +122,7 @@ class ReaLModel(nn.Module):
config: model_api.ReaLModelConfig,
dtype: Optional[torch.dtype] = None,
device: Optional[Union[str, torch.device]] = None,
hf_model_family: Optional[str] = None,
):
super().__init__()
if dtype is None:
@ -173,6 +174,14 @@ class ReaLModel(nn.Module):
)
self.contiguous_param = None
self.hf_model_family = hf_model_family
def save_to_hf(self, tokenizer, save_dir):
return getattr(self, f"to_{self.hf_model_family}")(tokenizer, save_dir)
def load_from_hf(self, load_dir):
return getattr(self, f"from_{self.hf_model_family}")(load_dir)
@property
def pre_process(self):
# A workaround to make Megatron-LM backend happy.
@ -918,12 +927,7 @@ def make_real_model(
model_path=model_path,
is_critic=is_critic or init_critic_from_actor,
)
m = ReaLModel(mconfig, dtype=dtype, device=device)
# Since we load from `hf_model_family`, we should save to `hf_model_family`.
# The following lines create convenient functions to save and load the model.
setattr(ReaLModel, "save_to_hf", getattr(ReaLModel, f"to_{hf_model_family}"))
setattr(ReaLModel, "load_from_hf", getattr(ReaLModel, f"from_{hf_model_family}"))
m = ReaLModel(mconfig, dtype=dtype, device=device, hf_model_family=hf_model_family)
if not init_from_scratch:
m._instantiation_hooks.append(

View File

@ -166,6 +166,27 @@ def build_leave_one_indices(
)
@torch.compile
def gather_logprobs(
logits: torch.Tensor,
labels: torch.Tensor,
):
"""Gather log probs from logits and labels.
Args:
logits (torch.FloatTensor): Shape [tot_seqlen]. The final value at the end of
each sequence is not used.
labels (torch.LongTensor): Labels or input_ids with shape [tot_seqlen].
The first value at the beginning of each sequence has no corresponding log prob.
Returns:
torch.FloatTensor: Log probability with shape [tot_seqlen - #seqs].
"""
log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
return log_probs_labels
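# A tiny numeric check of gather_logprobs: with a vocab of size 2, labels that
# pick the larger and the smaller logit give values of roughly -0.127 and -2.127.
import torch

_logits = torch.tensor([[2.0, 0.0], [2.0, 0.0]])
_labels = torch.tensor([0, 1])
_logp = torch.log_softmax(_logits, dim=-1).gather(-1, _labels.unsqueeze(-1)).squeeze(-1)
# _logp ≈ tensor([-0.1269, -2.1269]); gather_logprobs(_logits, _labels) computes the same.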
def gather_packed_shifted_log_probs(
logits: torch.FloatTensor,
cu_seqlens: torch.Tensor,
@ -174,11 +195,11 @@ def gather_packed_shifted_log_probs(
"""Gather log probs from packed input_ids and logits.
Args:
logits_ (torch.FloatTensor): Shape [tot_seqlen]. The final value at the end of
logits (torch.FloatTensor): Shape [tot_seqlen]. The final value at the end of
each sequence is not used.
cu_seqlens (torch.Tensor): Shape [#seqs + 1]. Indices marking the start
and end of each sequences.
labels_ (torch.LongTensor): Labels or input_ids with shape [tot_seqlen].
and end of each sequence.
labels (torch.LongTensor): Labels or input_ids with shape [tot_seqlen].
The first value at the beginning of each sequence has no corresponding log prob.
Returns:
@ -202,8 +223,7 @@ def gather_packed_shifted_log_probs(
# for i in range(cu_seqlens.shape[0] - 1)
# ])
# shift labels one step to the left and pad it to match the shape of logits
log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
log_probs_labels = gather_logprobs(logits, labels)
log_probs_labels = log_probs_labels[leave_one_indices]
assert log_probs_labels.shape[0] == logits_shape[0] - cu_seqlens.shape[0] + 1, (
log_probs_labels.shape,

View File

@ -127,9 +127,10 @@ class SlurmSchedulerClient(SchedulerClient):
)
wrap_cmd = "singularity exec "
if cluster_spec.name == "na132":
wrap_cmd += "--pid --no-home --writable-tmpfs "
wrap_cmd += "--pid "
if cluster_spec.gpu_type == "tesla":
wrap_cmd += "--nv "
wrap_cmd += "--no-home --writable-tmpfs "
if len(launch_info.env_vars) > 0:
wrap_cmd += f"{' '.join([f'--env {k}={v}' for k, v in launch_info.env_vars.items()])} "
if len(launch_info.container_mounts) > 0:

View File

@ -621,6 +621,9 @@ class RayController:
REAL_MATH_METADATA_PATH=os.environ.get("REAL_MATH_METADATA_PATH", ""),
REAL_CODE_METADATA_PATH=os.getenv("REAL_CODE_METADATA_PATH", ""),
FUNCTIONCALL_SERVICE_DOMAIN=os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", ""),
REAL_DUMP_TRACE=os.environ.get("REAL_DUMP_TRACE", "0"),
REAL_RECORD_PERFORMANCE=os.environ.get("REAL_RECORD_PERFORMANCE", "0"),
REAL_DUMP_MEMORY=os.environ.get("REAL_DUMP_MEMORY", "0"),
)
runtime_env = {
"env_vars": env_vars,

View File

@ -12,7 +12,7 @@ import torch.distributed as dist
from realhf import SequenceSample
from realhf.api.core.config import ModelName, ModelShardID
from realhf.base import constants
from realhf.base import constants, logging
from realhf.base.topology import ProcessTopology, new_or_get_group
from realhf.impl.model.comm.global_comm import filter_match_mwids
from realhf.system.redistributor import RedistribStep
@ -21,6 +21,8 @@ BCAST_GROUPS = {}
GATHER_GROUPS = {}
SCATTER_GROUPS = {}
logger = logging.getLogger("data_manager", "system")
class DataManager:
@ -325,8 +327,21 @@ class DataManager:
)
if dist.get_rank() == step.root:
scatter_list = []
for ids in step.ids:
# Scatter destinations include all DP, TP, and PP ranks
# and data is duplicated among TP/PP groups
# We allocate new memory for DP ranks, but use the same pointer
# for all TP and PP ranks to save memory.
scatter_clusters = []
for idx, ids in enumerate(step.ids):
for _ids, idx_list in scatter_clusters:
if set(ids) == set(_ids):
idx_list.append(idx)
break
else:
scatter_clusters.append((ids, [idx]))
scatter_list = [None for _ in range(len(step.ids))]
before_pad = []
for ids, idx_list in scatter_clusters:
for i in ids:
self.storage[i].to_device(constants.current_device())
samples = [self.storage[i] for i in ids]
@ -337,11 +352,17 @@ class DataManager:
for key in step.keys
]
)
scatter_list.append(data)
maxlen = max([x.shape[0] for x in scatter_list])
scatter_list = [self._pad_data(x, maxlen) for x in scatter_list]
if step.root not in step.dsts:
before_pad.append(data)
maxlen = max([x.shape[0] for x in before_pad])
after_pad = [self._pad_data(x, maxlen) for x in before_pad]
for (ids, idx_list), data in zip(scatter_clusters, after_pad):
for idx in idx_list:
scatter_list[idx] = data
assert all([torch.is_tensor(t) for t in scatter_list])
if step.root not in step.dsts:
idx = bisect.bisect(step.dsts, step.root)
scatter_list.insert(idx, buf)
else:

View File

@ -85,7 +85,11 @@ class MasterWorker(worker_base.Worker):
and config.exp_ctrl.ckpt_freq_steps is None
and config.exp_ctrl.ckpt_freq_secs is None
):
self.__ckpt_ctl = self.__save_ctl
self.__ckpt_ctl = timeutil.EpochStepTimeFreqCtl(
freq_epoch=config.exp_ctrl.save_freq_epochs,
freq_step=config.exp_ctrl.save_freq_steps,
freq_sec=config.exp_ctrl.save_freq_secs,
)
else:
self.__ckpt_ctl = timeutil.EpochStepTimeFreqCtl(
freq_epoch=config.exp_ctrl.ckpt_freq_epochs,

View File

@ -24,7 +24,6 @@ import numpy as np
import pynvml
import tabulate
import torch
import torch.distributed
import torch.distributed as dist
import torch.utils.data
@ -36,6 +35,8 @@ from realhf.base import (
constants,
gpu_utils,
logging,
name_resolve,
names,
network,
recover,
seeding,
@ -375,7 +376,8 @@ class ModelWorker(worker_base.Worker):
constants.MODEL_SAVE_ROOT,
constants.experiment_name(),
constants.trial_name(),
f"dataset_indices_{self._dp_rank}.npy",
"dataset_indices",
f"{self._dp_rank}.npy",
)
if os.path.exists(dataset_indices_path):
indices = np.load(dataset_indices_path).tolist()
@ -456,8 +458,8 @@ class ModelWorker(worker_base.Worker):
self.__request_cache = {}
self.__ack_cache = {}
self.__request_queue = queue.Queue(maxsize=8)
self.__reply_queue = queue.Queue(maxsize=8)
self.__request_queue = queue.Queue(maxsize=10240)
self.__reply_queue = queue.Queue(maxsize=10240)
self.__request_sample_size = dict()
self.__compute_input_queues = {
@ -469,6 +471,13 @@ class ModelWorker(worker_base.Worker):
for model_name in self.__models.keys()
}
# Intentionally set to a value smaller than -1.
self._last_param_realloc_step = -100
if self.__recover_run:
self._last_param_realloc_step = (
self.__recover_info.last_step_info.global_step
)
def __handle_one_rpc_hook(self, hook: str, hook_data: Any):
ret = None
@ -580,8 +589,10 @@ class ModelWorker(worker_base.Worker):
constants.MODEL_SAVE_ROOT,
constants.experiment_name(),
constants.trial_name(),
f"dataset_indices_{dp_rank}.npy",
"dataset_indices",
f"{dp_rank}.npy",
)
os.makedirs(os.path.dirname(dataset_indices_path), exist_ok=True)
if hasattr(self.__dataset, "filter") and os.path.exists(
eval_scores_path
):
@ -625,9 +636,6 @@ class ModelWorker(worker_base.Worker):
res = data_api.DataBatchMeta(
dp_rank=dp_rank,
meta_sample=meta_sample,
is_final_batch=(
self.__dataset_batch_counter == len(self.__dataloader) - 1
),
)
elif request.handle_name == "spec":
# Raw dataset without filtering.
@ -737,6 +745,31 @@ class ModelWorker(worker_base.Worker):
assert isinstance(res, dict), res
res.update({f"eval_{k}": v for k, v in ret.items()})
# update param realloc step after handling post hooks
if request.handle_name == "train_step":
self._last_param_realloc_step = max(self._last_param_realloc_step + 1, 1)
realloc_dir = os.path.join(
constants.PARAM_REALLOC_PATH,
constants.experiment_name(),
constants.trial_name(),
model_name.role,
)
save_meta = dict(
model_name=model_name,
save_backend=False,
save_dir=realloc_dir,
)
self.__save_model(save_meta)
name = names.model_version(
self.__experiment_name,
self.__trial_name,
model_name.role,
)
with constants.model_scope(model_name):
dist.barrier(group=constants.parallelism_group())
if constants.parallelism_rank() == 0:
name_resolve.add_subentry(name, str(self._last_param_realloc_step))
self.__reply_queue.put_nowait((request, res))
sample_count = data.bs if isinstance(data, data_api.SequenceSample) else 1
self.__request_sample_size[request.request_id] = sample_count
@ -761,7 +794,7 @@ class ModelWorker(worker_base.Worker):
or self.__enable_memory_dump
):
torch.cuda.synchronize()
torch.distributed.barrier(group=constants.parallelism_group())
dist.barrier(group=constants.cpu_parallelism_group())
# pfer can be a null context if enable_profiler is False
pfer = get_pytorch_profiler(
kernel_only=False, enabled=self.__enable_profiler
@ -780,7 +813,7 @@ class ModelWorker(worker_base.Worker):
or self.__enable_memory_dump
):
pfer.__exit__(None, None, None)
torch.distributed.barrier(group=constants.parallelism_group())
dist.barrier(group=constants.cpu_parallelism_group())
torch.cuda.synchronize()
tok = time.perf_counter()
rpc_time = tok - tik
@ -800,6 +833,17 @@ class ModelWorker(worker_base.Worker):
else:
self.__performance_recorder["time"].append(rpc_time)
with open(
os.path.join(
self._get_setup_logdir("performance"),
f"rpc-mw{self.__worker_index}.txt",
),
"a",
) as f:
f.write(
f"rpc: {rpc.name} rank: {dist.get_rank()} time: {rpc_time}\n"
)
if self.__enable_profiler:
if self._dp_rank == 0 and self._is_dp_head:
blogger.info(
@ -902,7 +946,7 @@ class ModelWorker(worker_base.Worker):
eval_scores.update(scores)
res.metadata.pop("scores")
dist.barrier(group=constants.parallelism_group())
dist.barrier(group=constants.cpu_parallelism_group())
if len(eval_scores) > 0 and self._dp_rank == 0 and self._is_dp_head:
with open(
eval_scores_path,
@ -938,7 +982,7 @@ class ModelWorker(worker_base.Worker):
self._clear_memory()
if constants.use_cuda():
torch.cuda.synchronize()
dist.barrier(group=constants.parallelism_group())
dist.barrier(group=constants.cpu_parallelism_group())
return res
@cuda_tmark("data_transfer", CUDATimeMarkType.comm)
@ -980,7 +1024,7 @@ class ModelWorker(worker_base.Worker):
with constants.model_scope(from_model_name):
from_model_ranks = constants.parallelism_group_ranks()
if not param_realloc_comm.is_trainable(from_model_name):
if torch.distributed.get_rank() not in from_model_ranks:
if dist.get_rank() not in from_model_ranks:
return
if not isinstance(self.__unwrapped_models[from_model_name], ReaLModel):
# We can only release the memory of ReaLModel,
@ -1006,7 +1050,7 @@ class ModelWorker(worker_base.Worker):
save_dir=realloc_dir,
)
self.__save_model(save_meta)
g = self.__param_realloc_info.param_realloc_model_group[
g = self.__param_realloc_info.param_realloc_model_cpu_group[
param_realloc_comm.ParamReallocModelPair(from_model_name, to_model_name)
]
dist.barrier(group=g)
@ -1018,7 +1062,7 @@ class ModelWorker(worker_base.Worker):
self.__load_model(load_meta)
# Remove the reallocated checkpoint.
with constants.model_scope(to_model_name):
dist.barrier(constants.parallelism_group())
dist.barrier(constants.cpu_parallelism_group())
if constants.parallelism_rank() == 0:
shutil.rmtree(realloc_dir, ignore_errors=True)
os.makedirs(realloc_dir, exist_ok=True)
@ -1086,7 +1130,7 @@ class ModelWorker(worker_base.Worker):
).is_symlink():
os.unlink(save_root / fn)
shutil.rmtree(save_dir, ignore_errors=True)
dist.barrier(constants.parallelism_group())
dist.barrier(constants.cpu_parallelism_group())
self._interface.save(self._model, save_dir)
# The `save` method of the interface may be empty.
# We only save the backend state if the parameters have been indeed saved.
@ -1153,11 +1197,12 @@ class ModelWorker(worker_base.Worker):
@cuda_tmark("post_response", CUDATimeMarkType.misc)
def maybe_post_responses(self):
ready_to_post = []
try:
request, res = self.__reply_queue.get_nowait()
ready_to_post.append((request, res))
except queue.Empty:
pass
while True:
try:
request, res = self.__reply_queue.get_nowait()
ready_to_post.append((request, res))
except queue.Empty:
break
batch_size = sample_size = 0
for request, res in ready_to_post:

View File

@ -3,7 +3,6 @@
import asyncio
import dataclasses
import itertools
import os
from collections import defaultdict
from typing import *

View File

@ -52,7 +52,7 @@ class Payload:
syn_reply_id: uuid.UUID = None
ack_reply_id: uuid.UUID = None
no_syn: bool = False
no_syn: bool = True
send_time: float = None
@ -164,7 +164,7 @@ class NameResolvingRequestClient:
datas: List[Any] | None = None,
payloads: List[Payload] | None = None,
verbose: bool = True,
no_syn: bool = False,
no_syn: bool = True,
) -> List[uuid.UUID]:
"""Send requests of type `handle_type` to all `handlers` with
corresponding `data`.

View File

@ -1,7 +1,7 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import asyncio
import dataclasses
import enum
import os
@ -547,6 +547,18 @@ class Worker:
"""Implemented by sub-classes."""
raise NotImplementedError()
@property
def running(self):
return self.__running
@property
def exiting(self):
return self.__exiting
@property
def is_configured(self):
return self.__is_configured
def configure(
self,
worker_info: system_api.WorkerInformation,
@ -625,10 +637,11 @@ class Worker:
def _exit_hook(self, exit_status: WorkerServerStatus):
logger.warning(f"Exit with {exit_status}, hook not implemented, pass.")
def exit(self):
def exit(self, err: bool = False):
self.logger.info("Exiting worker")
self._exit_hook(WorkerServerStatus.COMPLETED)
self.__set_status(WorkerServerStatus.COMPLETED)
status = WorkerServerStatus.ERROR if err else WorkerServerStatus.COMPLETED
self._exit_hook(status)
self.__set_status(status)
self.__exiting = True
def interrupt(self):
@ -681,8 +694,7 @@ class Worker:
logger.error(f"Worker encountered error {e}", exc_info=True)
if isinstance(e, WorkerException):
raise e
self.__set_status(WorkerServerStatus.ERROR)
self._exit_hook(WorkerServerStatus.ERROR)
self.exit(err=True)
raise e
def __host_key(self, key: str):
@ -694,6 +706,32 @@ class Worker:
name_resolve.watch_names(keys, call_back=self.exit)
class AsyncWorker(Worker):
async def _poll_async(self) -> PollResult:
raise NotImplementedError()
async def run_async(self):
self.logger.debug("Running worker now")
try:
while not self.exiting:
await asyncio.sleep(0.0)
self._server.handle_requests()
if not self.running:
await asyncio.sleep(0.05)
continue
if not self.is_configured:
raise RuntimeError("Worker is not configured")
r = await self._poll_async()
except KeyboardInterrupt:
self.exit()
except Exception as e:
logger.error(f"Worker encountered error {e}", exc_info=True)
if isinstance(e, WorkerException):
raise e
self.exit(err=True)
raise e
class MappingThread:
"""Wrapped of a mapping thread.

View File

@ -17,11 +17,7 @@ import torch.distributed as dist
from realhf.api.core.config import ModelName, ModelShardID
from realhf.api.core.data_api import SequenceSample
from realhf.base import constants, testing, topology
from realhf.base.testing import (
LocalMultiProcessTest,
PipeDataModelParallelTopology,
init_global_constants,
)
from realhf.base.testing import LocalMultiProcessTest, init_global_constants
from realhf.system.data_manager import DataManager
from realhf.system.redistributor import GlobalStorageTracker, RedistribPlanner
@ -143,7 +139,7 @@ def _test_data_transfer(
):
from_model_name = ModelName("data_transfer_test", 0)
from_topo = PipeDataModelParallelTopology(
from_topo = topology.PipeDataModelParallelTopology(
num_pp=from_pp_dp_mp[0],
num_mp=from_pp_dp_mp[-1],
num_dp=from_pp_dp_mp[1],
@ -152,7 +148,7 @@ def _test_data_transfer(
gradient_accumulation_fusion=True,
)
to_model_name = ModelName("data_transfer_test", 1)
to_topo = PipeDataModelParallelTopology(
to_topo = topology.PipeDataModelParallelTopology(
num_pp=to_pp_dp_mp[0],
num_mp=to_pp_dp_mp[-1],
num_dp=to_pp_dp_mp[1],

View File

@ -1,8 +1,8 @@
import functools
import gc
import json
import os
import pickle
import json
import time
from typing import *
@ -26,9 +26,9 @@ from realhf.api.core.model_api import ReaLModelConfig
from realhf.base import constants, logging
from realhf.base.network import find_free_port
from realhf.base.testing import (
init_global_constants,
_DEFAULT_EXPR_NAME,
_DEFAULT_TRIAL_NAME,
init_global_constants,
)
logger = logging.getLogger("test async ref-rew")

View File

@ -0,0 +1,242 @@
# Copyright 2025 Ant Group Inc.
import functools
import multiprocessing as mp
import os
import uuid
import numpy as np
import torch
import torch.distributed as dist
import transformers
from realhf.api.core import data_api, model_api
from realhf.api.core.config import ModelName
from realhf.api.core.data_api import MicroBatchSpec
from realhf.api.core.model_api import ReaLModelConfig
from realhf.base import constants, logging
from realhf.base.testing import init_global_constants
logger = logging.getLogger("test sglang backend")
def check_sequences_consistency(
batched_seq1: torch.LongTensor, batched_seq2: torch.LongTensor
):
matched_tokens = 0
matched_seqs = 0
total_tokens = 0
assert len(batched_seq1) == len(batched_seq2)
for i in range(len(batched_seq1)):
a = batched_seq1[i]
b = batched_seq2[i]
assert torch.is_tensor(a) and torch.is_tensor(b)
assert a.dim() == 1 and b.dim() == 1, (a.shape, b.shape)
gen_len = a.shape[0] if a.shape[0] < b.shape[0] else b.shape[0]
b = b[:gen_len]
a = a[:gen_len]
for j in range(gen_len):
if a[j] != b[j]:
logger.info(f"Mismatch at sequence {i} position {j}")
break
matched_tokens += 1
else:
matched_seqs += 1
total_tokens += gen_len
logger.info(
f"Matched {matched_seqs}/{len(batched_seq1)} "
f"sequences and {matched_tokens}/{total_tokens} tokens"
)
return (
matched_seqs,
matched_tokens,
float(matched_tokens) / total_tokens,
float(matched_seqs) / len(batched_seq1),
)
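# A quick usage sketch of the helper above: with these inputs one of the two
# sequences matches fully and 4 of the 5 compared tokens match.
_a = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
_b = [torch.tensor([1, 2, 9]), torch.tensor([4, 5])]
_matched_seqs, _matched_tokens, _tok_frac, _seq_frac = check_sequences_consistency(_a, _b)
# _matched_seqs == 1, _matched_tokens == 4, _tok_frac == 0.8, _seq_frac == 0.5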
def test_fn(
rank: int,
world_size: int,
path: str,
model_family_name: str,
dp: int,
pp: int,
tp: int,
):
assert not torch.cuda.is_initialized()
os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
torch.cuda.set_device(0)
assert world_size == (
dp * pp * tp
), f"dp={dp}, pp={pp}, tp={tp}, world_size={world_size}"
# Initialize distributed environment.
dist.init_process_group(
"nccl",
rank=rank,
world_size=world_size,
init_method="tcp://localhost:7777",
)
torch.cuda.set_device(0)
model_name = ModelName("default", 0)
constants.set_experiment_trial_names("slang-test", str(uuid.uuid4()))
init_global_constants(
num_dp=dp,
num_mp=tp,
num_pp=pp,
sequence_parallel=False,
model_name=model_name,
max_prompt_len=128,
)
from realhf.impl.model.nn.real_llm_api import ReaLModel, add_helper_functions
mconfig: ReaLModelConfig = getattr(ReaLModel, f"config_from_{model_family_name}")(
transformers.AutoConfig.from_pretrained(
path,
trust_remote_code=True,
force_download=True,
)
)
with constants.model_scope(model_name):
module = ReaLModel(mconfig, dtype=torch.float16, device="cuda")
module._instantiation_hooks.append(
lambda: getattr(module, f"from_{model_family_name}")(
load_dir=path, init_critic_from_actor=False
)
)
add_helper_functions(module)
module.instantiate()
module.eval()
tokenizer = data_api.load_hf_tokenizer(path)
from realhf.impl.model.backend.sglang import SGLangGenerationBackend
backend = SGLangGenerationBackend(model_path=path)
model = model_api.Model(
name=model_name,
module=module,
tokenizer=tokenizer,
device=module.device,
dtype=module.dtype,
)
ft_spec = model_api.FinetuneSpec(
total_train_epochs=1,
dataset_size=100,
train_batch_size=1,
)
model = backend.initialize(model, ft_spec)
gconfig = model_api.GenerationHyperparameters(
n=1,
max_new_tokens=32,
min_new_tokens=0,
greedy=True,
top_p=1.0,
top_k=int(1e8),
temperature=1.0,
use_cuda_graph=False,
)
bs = 8
for i in range(1):
seqlens = [torch.randint(5, 10, (1,)).cuda() for _ in range(bs)]
for s in seqlens:
dist.broadcast(s, src=0)
seqlens = [int(s) for s in seqlens]
token_ids = (
torch.randint(0, mconfig.vocab_size, (sum(seqlens),)).long().cuda()
)
dist.broadcast(token_ids, src=0)
max_seqlen = max(seqlens)
cu_seqlens = torch.nn.functional.pad(
torch.tensor(seqlens, device="cuda").cumsum(0),
(1, 0),
).int()
res = module.generate(
tokenizer=tokenizer,
packed_input_ids=token_ids,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
gconfig=gconfig,
)
gen_tokens1 = res.sequences
logprobs1 = res.scores
x = data_api.SequenceSample.from_default(
seqlens=seqlens,
ids=list(range(bs)),
data=dict(packed_input_ids=token_ids),
)
gen_tokens2, logprobs2, _ = model.module.generate(
input_=x,
mb_spec=MicroBatchSpec(),
tokenizer=tokenizer,
gconfig=gconfig,
)
if constants.model_parallel_rank() == 0:
# The outputs are None for tp_rank > 0 in SGLang
_, _, token_match_percent, seq_match_percent = (
check_sequences_consistency(gen_tokens1, gen_tokens2)
)
assert token_match_percent > 0.8, token_match_percent
assert seq_match_percent > 0.8, seq_match_percent
print("success")
# Clean up the distributed environment
dist.destroy_process_group()
def test_sglang_consistency(tp: int, dp: int, path: str, model_family_name: str):
mp.set_start_method("spawn", force=True)
world_size = dp * tp
procs = [
mp.Process(
target=test_fn,
args=(
i,
world_size,
),
kwargs=dict(
path=path,
model_family_name=model_family_name,
dp=dp,
pp=1,
tp=tp,
),
)
for i in range(world_size)
]
try:
for p in procs:
p.start()
for p in procs:
p.join()
except KeyboardInterrupt:
[p.terminate() for p in procs]
[p.join() for p in procs]
if __name__ == "__main__":
path = "/storage/openpsi/models/Qwen__Qwen2-1.5B-Instruct/"
model_family_name = "qwen2"
# test_fn(
# rank=0,
# world_size=1,
# path=path,
# model_family_name=model_family_name,
# dp=1,
# pp=1,
# tp=1,
# )
test_sglang_consistency(
tp=2,
dp=2,
path=path,
model_family_name=model_family_name,
)