Mirror of https://github.com/inclusionAI/AReaL
Commit f8586e47c8
@@ -13,7 +13,6 @@ from functioncall.base import logging

logger = logging.getLogger("Functioncall")


FUNCTIONCALL_SERVICE_DOMAIN = os.getenv(
    "FUNCTIONCALL_SERVICE_DOMAIN",
    "",
@@ -31,8 +30,8 @@ async def async_invoke_function(
    function_name: str,
    timeout: aiohttp.ClientTimeout,
    payload: Dict[str, Any] = None,
-   max_retries: int = 3,
-   initial_retry_interval: float = 0.1,
+   max_retries: int = 100,
+   initial_retry_interval: float = 0.5,
    max_retry_interval: float = 10.0,
):
    if payload is None:
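The retry loop that consumes these arguments is not visible in this hunk. As a rough illustration only (the exponential backoff capped at max_retry_interval is an assumption, not something this diff confirms), such parameters are typically used like this:

import asyncio
import random


async def retry_with_backoff(
    op,
    max_retries: int = 100,
    initial_retry_interval: float = 0.5,
    max_retry_interval: float = 10.0,
):
    # Hypothetical helper: retry an async operation, doubling the sleep after
    # each failure and capping it at max_retry_interval.
    delay = initial_retry_interval
    for attempt in range(max_retries):
        try:
            return await op()
        except Exception:
            if attempt == max_retries - 1:
                raise
            # Small jitter so many concurrent callers do not retry in lockstep.
            await asyncio.sleep(delay * (1.0 + 0.1 * random.random()))
            delay = min(delay * 2, max_retry_interval)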
@@ -63,7 +62,7 @@ async def async_invoke_function(

        except asyncio.TimeoutError as e:
            logger.warning(
-               f"Request timeout after {timeout}s, URL: {url}, Headers: {session.headers}, payload: {payload}"
+               f"Request timeout after {timeout}s, URL: {url}, Headers: {session.headers}"
            )
            break

@@ -85,7 +84,7 @@ async def async_invoke_function(


async def batch_function_call_async(
-   payload_list, function_name, timeout, concurrency=1000
+   payload_list, function_name, timeout, concurrency=1500
):
    connector = aiohttp.TCPConnector(limit=0)
    async with aiohttp.ClientSession(connector=connector) as session:
@@ -120,7 +119,7 @@ async def batch_function_call_async(
        p90 = calculate_percentile(elapsed_times, 90)
        p99 = calculate_percentile(elapsed_times, 99)
        logger.info(
-           f"Longest functioncall took {max_elapsed:.4f} seconds, header: {max_elapsed_header}, p50: {p50}, p90: {p90}, p99: {p99}"
+           f"Longest functioncall {function_name} took {max_elapsed:.4f} seconds, header: {max_elapsed_header}, timeout: {timeout}, p50: {p50}, p90: {p90}, p99: {p99}"
        )

        return data_list
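calculate_percentile is referenced above but is not part of this diff; a nearest-rank helper along these lines (hypothetical, shown only to make the logging readable) would suffice:

def calculate_percentile(values, percentile):
    # Nearest-rank percentile of a non-empty list of latencies (in seconds).
    ordered = sorted(values)
    rank = round(percentile / 100 * (len(ordered) - 1))
    return ordered[max(0, min(len(ordered) - 1, rank))]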
@@ -129,19 +128,21 @@ async def batch_function_call_async(
def get_function_name(runtime_type):
    if runtime_type == "python_code":
        return "realhf_code_verify"
+   if runtime_type == "python_live_code_bench":
+       return "python_live_code_bench"
    elif runtime_type == "python_math":
-       return "realhf_math_verify"
+       return "python_math"
    return "empty_code"


-def batch_function_call(payload_list, runtime_type, timeout=30):
+def batch_function_call(payload_list, runtime_type, timeout):
    start_time = time.time()
    function_name = get_function_name(runtime_type)
    result = asyncio.run(
        batch_function_call_async(payload_list, function_name, timeout)
    )
    execution_time = time.time() - start_time
-   logger.debug(
+   logger.info(
        f"Batch function call done, runtime type: {runtime_type}, batch size: {len(payload_list)}, cost: {execution_time * 1000:.0f} ms"
    )
    return result
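A minimal usage sketch of the updated entry point (the payload shape follows the math handler shown later in this commit; the values are made up):

payloads = [
    {
        "answers": ["\\boxed{1}"],
        "solutions": ["1"],
        "query_ids": ["q-0"],
    }
]
# timeout is now a required argument; 100 seconds is an arbitrary choice here.
results = batch_function_call(payloads, "python_math", timeout=100)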
@@ -714,7 +714,7 @@ def reliability_guard(maximum_memory_bytes=None):

    import builtins

-   builtins.exit = None
+   # builtins.exit = None
    builtins.quit = None

    import os
@@ -29,7 +29,7 @@ def load_problems_with_testcase_batch(path, debug=False, test_case_batch_size=No

        # parse one problem
        row = json.loads(line.strip().decode("utf-8"))
-       query_id = str(row["id"])
+       query_id = str(row.get("id", row.get("query_id")))
        input_output = json.loads(row["input_output"]) if "input_output" in row else {}
        inputs = input_output.get("inputs", [])
        outputs = input_output.get("outputs", [])
@@ -66,7 +66,9 @@ def load_problems_with_testcase_batch(path, debug=False, test_case_batch_size=No
global_problems = None


-def code_verify(generateds, query_ids, debug=False, timeout=20, timeout_for_testcase=6):
+def code_verify(
+    generateds, query_ids, debug=False, timeout=1000, timeout_for_testcase=6
+):
    assert len(generateds) == len(query_ids), (
        len(generateds),
        len(query_ids),
@@ -116,14 +118,11 @@ def code_verify(generateds, query_ids, debug=False, timeout=20, timeout_for_test
        value = 0
        if rsp and "result" in rsp and not any(x != True for x in rsp["result"]):
            value = 1
-
        else:
            logger.debug(
-               f"Functioncall code verify failed, query index: {query_index}, query id: {query_id}, results: {rsp}"
+               f"Functioncall code verify not passed, query index: {query_index}, query id: {query_id}, results: {rsp}"
            )

-       # logger.debug(f"query index: {idx}, value: {value}, results[query_index]: {results[query_index]}")
-
        results[query_index] = results[query_index] and value

    return results
@@ -133,6 +132,15 @@ if __name__ == "__main__":

    def create_test_params(count=10):
        global global_problems
+       if global_problems is None:
+           global_problems = load_problems_with_testcase_batch(
+               os.getenv(
+                   "REAL_CODE_METADATA_PATH",
+                   "/storage/datasets/codeparrot-apps-test.jsonl",
+               ),
+               debug=True,
+               test_case_batch_size=20,
+           )
        codes, query_ids = [], []
        idx = 0
        for query_id, problems in global_problems.items():
@@ -149,6 +157,6 @@ if __name__ == "__main__":

        return codes, query_ids

-   codes, query_ids = create_test_params(1000)
-   result = code_verify(codes, query_ids, True, 100)
+   codes, query_ids = create_test_params(100)
+   result = code_verify(codes, query_ids, True)
    print(result)
@@ -1,4 +1,5 @@
import json
+import time
from parser import extract_answer

from grader import math_equal
@@ -6,11 +7,8 @@ from grader import math_equal

def process_results(answer, solution):
    extracted_answer = extract_answer(answer, "math", use_last_number=False)
-   extracted_solution = extract_answer(solution, "math", use_last_number=True)
+   extracted_solution = solution

-   print(
-       f"extracted_answer: {extracted_answer}, extracted_solution: {extracted_solution}, equal: {math_equal(extracted_answer, extracted_solution)}"
-   )
    if extracted_answer is None or extracted_answer.strip() in ["None", "none", ""]:
        retval = 0
    elif math_equal(extracted_answer, extracted_solution, timeout=True):
@@ -24,16 +22,19 @@ def process_results(answer, solution):
def handle(event, context):
    answers = event.get("answers", "")
    solutions = event.get("solutions", "")

    # print(f"math payload:{event}\n")
    # answers and solutions are json lists, and call process_results then collect result into a list
    if isinstance(answers, str):
        answers = json.loads(answers)
    if isinstance(solutions, str):
        solutions = json.loads(solutions)
+   query_ids = event.get("query_ids", "")

    results = []
-   for answer, solution in zip(answers, solutions):
-       results.append(process_results(answer, solution))
+   for answer, solution, query_id in zip(
+       answers,
+       solutions,
+       query_ids,
+   ):
+       start_time = time.time()
+       result = process_results(answer, solution)
+       results.append(result)
+       print(
+           f"query_id: {query_id}, result: {result}, current cost: {(time.time() - start_time) * 1000:.0f} ms"
+       )

    return results
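For reference, a hypothetical event for the updated handler (keys taken from the code above, values invented) and its call:

event = {
    "answers": ["\\boxed{-\\frac{2}{3}}"],
    "solutions": ["-\\frac{2}{3}"],
    "query_ids": ["fe11b471-1aa9-4867-958f-a0a811c85f92"],
}
print(handle(event, context=None))  # e.g. [1] when the answer matches the solution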
@@ -22,7 +22,7 @@ def loadJson(dataDir):
id2info = None


-def math_verify(generateds: List, query_ids: List, batch_size=10, timeout=5) -> List:
+def math_verify(generateds: List, query_ids: List, batch_size=10, timeout=1000) -> List:
    global id2info
    if id2info is None:
        id2info = loadJson(
@@ -43,7 +43,7 @@ def math_verify(generateds: List, query_ids: List, batch_size=10, timeout=5) ->
    for idx, (query_id, generated) in enumerate(zip(query_ids, generateds)):
        base_query_id = query_id.split("@idx:")[0]
        info = id2info[base_query_id]
-       for cur_solution in info["solutions"]:
+       for cur_solution in info["answers"]:
            parameters.append((generated, cur_solution, idx))
            query_indices.append(idx)

@@ -58,7 +58,6 @@ def math_verify(generateds: List, query_ids: List, batch_size=10, timeout=5) ->
            "query_ids": [query_ids[i] for i in indices],
        }

-       # print(batch_args)
        batch_args_list.append(batch_args)

    results_batch = batch_function_call(batch_args_list, "python_math", timeout)
@@ -67,15 +66,15 @@ def math_verify(generateds: List, query_ids: List, batch_size=10, timeout=5) ->
    # Map results back to original indices
    index = 0
    for batch_idx, results in enumerate(results_batch):
-       query_index = query_indices[index]
        if not isinstance(results, list) or len(results) == 0:
            index += len(batch_args_list[batch_idx]["answers"])
            logger.warning(
-               f"Invalid functioncall math results: {results}, batch index:{batch_idx}, query index: {query_index}, params: {batch_args_list[batch_idx]['answers']}."
+               f"Invalid functioncall math results: {results}, batch index:{batch_idx}, query index: {query_indices[index]}, params: {batch_args_list[batch_idx]['answers']}."
            )
            continue

        for result in results:
+           query_index = query_indices[index]
            if (
                isinstance(result, list)
                and len(result) > 0
@@ -90,23 +89,33 @@ def math_verify(generateds: List, query_ids: List, batch_size=10, timeout=5) ->
            index += 1

    logger.info(
-       f"verify math with query size={len(query_ids)}, takes {time.time() - start_time:.4f} seconds, result: {labels}"
+       f"verify math with query size={len(query_ids)}, takes {time.time() - start_time:.4f} seconds"
    )
    return labels


if __name__ == "__main__":
-   sample = {
-       "prompt": "",
-       "query_id": "fe11b471-1aa9-4867-958f-a0a811c85f92",
-       "answer": "\\boxed{-\\frac{2}{3}}",
-   }
-   start_time = time.time()
-   batch_size = 10
-   result = math_verify(
-       [sample["answer"]] * batch_size, [sample["query_id"] for _ in range(batch_size)]
-   )
+   # sample = {
+   #     "prompt": "",
+   #     "query_id": "fe11b471-1aa9-4867-958f-a0a811c85f92",
+   #     "answer": "\\boxed{-\\frac{1}{30}}",
+   # }

-   hint = f"batch_size: {batch_size}, total time : {(time.time() - start_time) * 1000:.0f} ms"
+   if id2info is None:
+       id2info = loadJson(
+           os.getenv(
+               "REAL_MATH_MEATADATA_PATH",
+               "/storage/datasets/id2info.json",
+           )
+       )
+
+   answers = []
+   query_ids = []
+
+   for id, value in id2info.items():
+       answers.append(value["solutions"][0])
+       query_ids.append(id)
+
+   start_time = time.time()
+   result = math_verify(answers[:200], query_ids[:200])
    print(result)
-   print(hint)
@ -684,7 +684,6 @@ class SequenceSample:
|
|||
class DataBatchMeta:
|
||||
dp_rank: int
|
||||
meta_sample: SequenceSample | None
|
||||
is_final_batch: bool
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
|
|
@ -3,10 +3,12 @@
|
|||
# Licensed under the Apache License, Version 2.0 (the "License").
|
||||
|
||||
import abc
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import keyword
|
||||
from typing import *
|
||||
|
||||
import aiohttp
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.utils.data
|
||||
|
@ -106,6 +108,139 @@ class GenerationHyperparameters:
|
|||
f"To use CUDAGraph, ReaL's PyTorch version should be at least 2.3.0."
|
||||
)
|
||||
|
||||
def new(self, **kwargs):
|
||||
args = dataclasses.asdict(self)
|
||||
args.update(kwargs)
|
||||
return GenerationHyperparameters(**args)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class APIGenerateInput:
|
||||
qid: Hashable
|
||||
group_idx: int
|
||||
prompt_ids: List[int]
|
||||
input_ids: List[int]
|
||||
gconfig: GenerationHyperparameters
|
||||
stop_token_ids: List[int] = dataclasses.field(default_factory=list)
|
||||
return_logprob: bool = False
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class APIGenerateOutput:
|
||||
qid: Hashable
|
||||
group_idx: int
|
||||
prompt_ids: List[int]
|
||||
input_ids: List[int]
|
||||
output_ids: List[int] = dataclasses.field(default_factory=list)
|
||||
output_logprobs: List[int] = dataclasses.field(default_factory=list)
|
||||
no_eos: bool = True
|
||||
success: bool = False
|
||||
latency: float = 0.0
|
||||
ttft: float = 0.0 # Time to first token
|
||||
itl: List[float] = dataclasses.field(
|
||||
default_factory=list
|
||||
) # List of inter-token latencies
|
||||
error: str = ""
|
||||
|
||||
@classmethod
|
||||
def from_input(cls, inp: APIGenerateInput):
|
||||
return cls(
|
||||
qid=inp.qid,
|
||||
group_idx=inp.group_idx,
|
||||
prompt_ids=inp.prompt_ids,
|
||||
input_ids=inp.input_ids,
|
||||
)
|
||||
|
||||
@property
|
||||
def output_len(self):
|
||||
return len(self.output_ids)
|
||||
|
||||
@property
|
||||
def input_len(self):
|
||||
return len(self.input_ids)
|
||||
|
||||
@property
|
||||
def prompt_len(self):
|
||||
return len(self.prompt_ids)
|
||||
|
||||
@property
|
||||
def gen_len(self):
|
||||
return self.output_len + self.input_len - self.prompt_len
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BundledGenerationOutputs:
|
||||
qid: Hashable
|
||||
prompt_ids: List[int]
|
||||
seqs: List[List[int]]
|
||||
no_eos: List[bool]
|
||||
|
||||
@classmethod
|
||||
def from_single(cls, outputs: List[APIGenerateOutput]):
|
||||
assert len(set(o.qid for o in outputs)) == 1
|
||||
return cls(
|
||||
qid=outputs[0].qid,
|
||||
prompt_ids=outputs[0].prompt_ids,
|
||||
seqs=[o.input_ids + o.output_ids for o in outputs],
|
||||
no_eos=[o.no_eos for o in outputs],
|
||||
)
|
||||
|
||||
@property
|
||||
def seqlens(self):
|
||||
return [len(seq) for seq in self.seqs]
|
||||
|
||||
@property
|
||||
def prompt_len(self):
|
||||
return len(self.prompt_ids)
|
||||
|
||||
|
||||
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
|
||||
|
||||
|
||||
class LLMAPIClient:
|
||||
def __init__(
|
||||
self, generate_url: str, update_weights_url: str, concurrency_limit: int = -1
|
||||
):
|
||||
self.update_weights_url = update_weights_url
|
||||
self.generate_url = generate_url
|
||||
self.concurrency_limit = concurrency_limit
|
||||
|
||||
self.session: aiohttp.ClientSession
|
||||
self.semaphore: asyncio.Semaphore
|
||||
|
||||
async def __aenter__(self):
|
||||
conn = aiohttp.TCPConnector(limit=0, ttl_dns_cache=300)
|
||||
self.session = aiohttp.ClientSession(
|
||||
timeout=AIOHTTP_TIMEOUT,
|
||||
connector=conn,
|
||||
read_bufsize=1024 * 1024 * 10,
|
||||
)
|
||||
if self.concurrency_limit > 0:
|
||||
self.semaphore = asyncio.Semaphore(self.concurrency_limit)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.session:
|
||||
await self.session.close()
|
||||
|
||||
async def async_add_generate_request(
|
||||
self, req: APIGenerateInput, stream: bool = True
|
||||
) -> APIGenerateOutput:
|
||||
|
||||
if self.concurrency_limit > 0:
|
||||
async with self.semaphore:
|
||||
return await self._do_generate(req, stream=stream)
|
||||
else:
|
||||
return await self._do_generate(req, stream=stream)
|
||||
|
||||
async def _do_generate(
|
||||
self, req: APIGenerateInput, stream: bool = True
|
||||
) -> APIGenerateOutput:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def async_update_weights_from_disk(self, path):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ReaLMoEConfig:
|
||||
|
|
|
@ -131,6 +131,65 @@ class vLLMConfig:
|
|||
additional_engine_args: Dict = dataclasses.field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SGLangConfig:
|
||||
disable_cuda_graph: bool = False
|
||||
disable_radix_cache: bool = False
|
||||
disable_jump_forward: bool = False
|
||||
disable_cuda_graph_padding: bool = False
|
||||
enable_nccl_nvls: bool = False
|
||||
disable_outlines_disk_cache: bool = False
|
||||
disable_custom_all_reduce: bool = False
|
||||
disable_mla: bool = False
|
||||
disable_overlap_schedule: bool = False
|
||||
enable_mixed_chunk: bool = False
|
||||
enable_dp_attention: bool = False
|
||||
enable_ep_moe: bool = False
|
||||
enable_torch_compile: bool = False
|
||||
torch_compile_max_bs: int = 32
|
||||
cuda_graph_max_bs: Optional[int] = None
|
||||
cuda_graph_bs: Optional[List[int]] = None
|
||||
torchao_config: str = ""
|
||||
enable_nan_detection: bool = False
|
||||
enable_p2p_check: bool = False
|
||||
triton_attention_reduce_in_fp32: bool = False
|
||||
triton_attention_num_kv_splits: int = 8
|
||||
num_continuous_decode_steps: int = 1
|
||||
enable_memory_saver: bool = False
|
||||
allow_auto_truncate: bool = False
|
||||
return_hidden_states: bool = False
|
||||
# NOTE: to avoid the illegal memory access error
|
||||
attention_backend: Optional[str] = "triton"
|
||||
sampling_backend: Optional[str] = None
|
||||
context_length: Optional[int] = None
|
||||
mem_fraction_static: Optional[float] = None
|
||||
max_running_requests: Optional[int] = None
|
||||
max_total_tokens: Optional[int] = None
|
||||
chunked_prefill_size: Optional[int] = None
|
||||
max_prefill_tokens: int = 16384
|
||||
schedule_policy: str = "lpm"
|
||||
schedule_conservativeness: float = 1.0
|
||||
cpu_offload_gb: int = 0
|
||||
hybrid_train: bool = False
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class DistributedDataParallelConfig:
|
||||
"""Configuration for Megatron DistributedDataParallel.
|
||||
Some default options have been overwritten.
|
||||
"""
|
||||
|
||||
grad_reduce_in_fp32: bool = False
|
||||
overlap_grad_reduce: bool = True
|
||||
overlap_param_gather: bool = False
|
||||
align_param_gather: bool = False
|
||||
use_distributed_optimizer: bool = True
|
||||
check_for_nan_in_grad: bool = False
|
||||
bucket_size: Optional[int] = None
|
||||
average_in_collective: bool = False
|
||||
fp8_param_gather: bool = False
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MegatronConfig:
|
||||
"""When using the DistributedOptimizer of Megatron, parameters and
|
||||
|
@ -177,12 +236,11 @@ class MegatronConfig:
|
|||
make it functionally correct. The DeepSpeed code is too hard to read and modify.
|
||||
"""
|
||||
|
||||
overlap_grad_reduce: bool = True
|
||||
overlap_param_gather: bool = False
|
||||
accumulate_allreduce_grads_in_fp32: bool = False
|
||||
|
||||
# addtional args
|
||||
additional_config: Dict = dataclasses.field(default_factory=dict)
|
||||
ddp: DistributedDataParallelConfig = dataclasses.field(
|
||||
default_factory=DistributedDataParallelConfig
|
||||
)
|
||||
# Don't use MegatronOptimizerConfig here because OmegaConf
|
||||
# does not recognize the annotation "torch.dtype"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
@ -233,6 +291,7 @@ class ModelTrainEvalConfig:
|
|||
)
|
||||
megatron: MegatronConfig = dataclasses.field(default_factory=MegatronConfig)
|
||||
vllm: vLLMConfig = dataclasses.field(default_factory=vLLMConfig)
|
||||
sglang: SGLangConfig = dataclasses.field(default_factory=SGLangConfig)
|
||||
init_from_scratch: bool = False
|
||||
init_critic_from_actor: bool = False
|
||||
|
||||
|
|
|
@ -7,10 +7,13 @@ import functools
|
|||
import json
|
||||
import multiprocessing
|
||||
import os
|
||||
import pickle
|
||||
import re
|
||||
import socket
|
||||
import subprocess
|
||||
|
||||
try:
|
||||
import uvloop
|
||||
|
||||
uvloop.install()
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
pass
|
||||
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
|
@ -60,7 +63,7 @@ def main_worker(args):
|
|||
worker_index_start + args.wprocs_per_jobstep,
|
||||
args.wprocs_in_job + args.wproc_offset,
|
||||
)
|
||||
if args.worker_type != "model_worker":
|
||||
if args.worker_type == "master_worker":
|
||||
try:
|
||||
# CUDA_VISIBLE_DEVICES is set by slurm on PPU nodes
|
||||
# we need to remove it on CPU workers
|
||||
|
|
|
@ -80,6 +80,7 @@ TRITON_CACHE_PATH = f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/triton"
|
|||
DATASET_CACHE_PATH = f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/datasets"
|
||||
PROFILER_CACHE_PATH = f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/profiler"
|
||||
PARAM_REALLOC_PATH = f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/param_realloc"
|
||||
SGLANG_CACHE_PATH = f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/sglang"
|
||||
TORCH_EXTENSIONS_DIR = (
|
||||
f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/torch/extensions"
|
||||
)
|
||||
|
@ -165,6 +166,7 @@ os.makedirs(DATASET_CACHE_PATH, exist_ok=True)
|
|||
os.makedirs(PROFILER_CACHE_PATH, exist_ok=True)
|
||||
os.makedirs(TORCH_EXTENSIONS_DIR, exist_ok=True)
|
||||
os.makedirs(QUICKSTART_EXPR_CACHE_PATH, exist_ok=True)
|
||||
os.makedirs(SGLANG_CACHE_PATH, exist_ok=True)
|
||||
|
||||
# _model_name will be changed in the model_scope context manager
|
||||
_model_name: "ModelName" = None
|
||||
|
@ -177,6 +179,9 @@ _grids: Dict["ModelName", "ParallelGrid"] = {}
|
|||
_pgroups: Dict["ModelName", Any] = (
|
||||
{}
|
||||
) # torch.distributed.ProcessGroup, not type hint here to avoid importing torch
|
||||
_cpu_pgroups: Dict["ModelName", Any] = (
|
||||
{}
|
||||
) # torch.distributed.ProcessGroup, not type hint here to avoid importing torch
|
||||
_pgroup_ranks: Dict["ModelName", List[int]] = {}
|
||||
_self_group = None
|
||||
_rank_mapping: Dict["ModelName", Dict["ModelShardID", int]] = {}
|
||||
|
@ -259,6 +264,13 @@ def set_parallelism_group(model_name: "ModelName", pgroup, ranks):
|
|||
_pgroup_ranks[model_name] = ranks
|
||||
|
||||
|
||||
def set_cpu_parallelism_group(model_name: "ModelName", pgroup):
|
||||
global _cpu_pgroups
|
||||
if model_name in _cpu_pgroups:
|
||||
raise RuntimeError(f"Parallelism group for model {model_name} is already set.")
|
||||
_cpu_pgroups[model_name] = pgroup
|
||||
|
||||
|
||||
def set_self_group(pgroup):
|
||||
global _self_group
|
||||
if _self_group is not None:
|
||||
|
@ -382,6 +394,15 @@ def parallelism_group():
|
|||
return _pgroups[_model_name]
|
||||
|
||||
|
||||
def cpu_parallelism_group():
|
||||
"""Returns the GLOO 3D parallelism group of a specific model."""
|
||||
if _model_name is None:
|
||||
raise RuntimeError("Global constant `model_name` is accessed before set.")
|
||||
if _cpu_pgroups.get(_model_name, None) is None:
|
||||
raise RuntimeError(f"Parallelism group for model {_model_name} is not set.")
|
||||
return _cpu_pgroups[_model_name]
|
||||
|
||||
|
||||
def parallelism_group_ranks():
|
||||
if _model_name is None:
|
||||
raise RuntimeError("Global constant `model_name` is accessed before set.")
|
||||
|
@ -451,6 +472,10 @@ def prev_pipe_stage():
|
|||
) % pipe_parallel_world_size()
|
||||
|
||||
|
||||
def is_dp_head():
|
||||
return is_last_pipe_stage() and model_parallel_rank() == 0
|
||||
|
||||
|
||||
def model_parallel_rank() -> int:
|
||||
"""Return the rank inside the tensor parallelism group."""
|
||||
try:
|
||||
|
|
|
@@ -2,7 +2,7 @@
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").

-import heapq
+import bisect
import itertools
from typing import Any, List, Tuple, Union

@@ -159,8 +159,9 @@ def _ffd_allocate(
    as possible.

    1. Sort the numbers in reverse order.
-   2. If the number of groups is less than, create a new group.
-   3. If the new number fits into the smallest group, add it into the group.
+   2. If the number of groups is less than `min_groups`, create a new group.
+   3. For a new number, find all groups with the capacity to hold the new number.
+      Put the new number into the group with the smallest size.
    4. Otherwise, create a new group.
    """
    value_indices = np.argsort(-values)
@@ -172,12 +173,17 @@
            len(group_values) < min_groups
            or group_values[0][0] + values[idx] > capacity
        ):
-           heapq.heappush(group_values, (float(values[idx]), group_cnt))
+           bisect.insort(group_values, (float(values[idx]), group_cnt))
            group_indices.append([idx])
            group_cnt += 1
        else:
-           v, group_idx = heapq.heappop(group_values)
-           heapq.heappush(group_values, (float(v + values[idx]), group_idx))
+           i = bisect.bisect_right(group_values, (capacity - values[idx], len(values)))
+           candidates = [group_values[j][1] for j in range(i)]
+           lens = [len(group_indices[g]) for g in candidates]
+           j = np.argmin(lens)
+           v, group_idx = group_values.pop(j)
+           assert group_idx == candidates[j]
+           bisect.insort(group_values, (float(values[idx] + v), group_idx))
            group_indices[group_idx].append(idx)
    return group_indices
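A small usage sketch of the modified allocator (the exact signature of _ffd_allocate is not shown in this diff, so the keyword names below are assumptions based on the body above):

import numpy as np

# Pack six hypothetical sequence lengths into groups of capacity 8,
# starting from at least two groups.
values = np.array([5, 4, 3, 3, 2, 1])
groups = _ffd_allocate(values, capacity=8, min_groups=2)
# Each entry is a list of indices into `values`, and every group's total
# stays within the capacity; under these assumptions the result would be
# something like [[0, 3], [1, 2], [4, 5]].
print(groups)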
@ -268,6 +268,23 @@ class MemoryNameRecordRepository(NameRecordRepository):
|
|||
|
||||
class NfsNameRecordRepository(NameRecordRepository):
|
||||
RECORD_ROOT = f"{cluster_spec.fileroot}/name_resolve/"
|
||||
LOCK_FILE = f"{cluster_spec.fileroot}/name_resolve/LOCK"
|
||||
os.makedirs(RECORD_ROOT, exist_ok=True)
|
||||
|
||||
@staticmethod
|
||||
def locked(fn: Callable) -> Callable:
|
||||
def fn_(*args, **kwargs):
|
||||
import fcntl
|
||||
|
||||
with open(NfsNameRecordRepository.LOCK_FILE, "w") as fd:
|
||||
fcntl.flock(fd, fcntl.LOCK_EX)
|
||||
try:
|
||||
res = fn(*args, **kwargs)
|
||||
finally:
|
||||
fcntl.flock(fd, fcntl.LOCK_UN)
|
||||
return res
|
||||
|
||||
return fn_
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.__to_delete = set()
|
||||
|
@ -280,6 +297,7 @@ class NfsNameRecordRepository(NameRecordRepository):
|
|||
def __file_path(name):
|
||||
return os.path.join(NfsNameRecordRepository.__dir_path(name), "ENTRY")
|
||||
|
||||
@locked
|
||||
def add(
|
||||
self,
|
||||
name,
|
||||
|
@ -299,6 +317,7 @@ class NfsNameRecordRepository(NameRecordRepository):
|
|||
if delete_on_exit:
|
||||
self.__to_delete.add(name)
|
||||
|
||||
@locked
|
||||
def delete(self, name):
|
||||
path = self.__file_path(name)
|
||||
if not os.path.isfile(path):
|
||||
|
@ -322,12 +341,22 @@ class NfsNameRecordRepository(NameRecordRepository):
|
|||
else:
|
||||
logger.info("No such name resolve path: %s", dir_path)
|
||||
|
||||
@locked
|
||||
def get(self, name):
|
||||
path = self.__file_path(name)
|
||||
if not os.path.isfile(path):
|
||||
raise NameEntryNotFoundError(path)
|
||||
with open(path, "r") as f:
|
||||
return f.read().strip()
|
||||
for _ in range(100):
|
||||
# HACK: dealing with the possible OSError: Stale file handle
|
||||
try:
|
||||
with open(path, "r") as f:
|
||||
return f.read().strip()
|
||||
except OSError as e:
|
||||
if e.errno == 116:
|
||||
time.sleep(5e-3)
|
||||
continue
|
||||
raise e
|
||||
raise RuntimeError("Failed to read value for %s" % name)
|
||||
|
||||
def get_subtree(self, name_root):
|
||||
dir_path = self.__dir_path(name_root)
|
||||
|
|
|
@ -60,3 +60,7 @@ def distributed_local_peer(experiment_name, trial_name, host_name, model_name):
|
|||
|
||||
def distributed_master(experiment_name, trial_name, model_name):
|
||||
return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/distributed/master/{model_name}"
|
||||
|
||||
|
||||
def model_version(experiment_name, trial_name, model_name):
|
||||
return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/model_version/{model_name}"
|
||||
|
|
|
@@ -6,12 +6,27 @@ import socket
from contextlib import closing


-def find_free_port():
-   """From stackoverflow Issue 1365265."""
-   with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
-       s.bind(("", 0))
-       s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-       return s.getsockname()[1]
+def find_free_port(low=1, high=65536, exclude_ports=None):
+   """Find a free port within the specified range, excluding certain ports."""
+   if exclude_ports is None:
+       exclude_ports = set()
+
+   while True:
+       with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
+           s.bind(("", 0))
+           s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+           port = s.getsockname()[1]
+           if low <= port <= high and port not in exclude_ports:
+               return port
+
+
+def find_multiple_free_ports(count, low=1, high=65536):
+   """Find multiple mutually exclusive free ports."""
+   free_ports = set()
+   for _ in range(count):
+       port = find_free_port(low, high, exclude_ports=free_ports)
+       free_ports.add(port)
+   return list(free_ports)


def gethostname():
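Usage of the new helpers is straightforward; for example (illustrative only):

# Reserve three distinct free ports, restricted to a user-chosen range.
ports = find_multiple_free_ports(3, low=20000, high=60000)
print(ports)  # e.g. [23451, 40012, 58317]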
@ -2,11 +2,13 @@
|
|||
# Copyright 2024 Wei Fu & Zhiyu Mei
|
||||
# Licensed under the Apache License, Version 2.0 (the "License").
|
||||
import os
|
||||
from importlib.metadata import version
|
||||
from typing import List
|
||||
|
||||
import realhf.api.core.model_api as model_api
|
||||
from packaging.version import Version
|
||||
|
||||
from realhf.api.quickstart.device_mesh import RPCAllocation
|
||||
from realhf.api.quickstart.model import ModelTrainEvalConfig, vLLMConfig
|
||||
from realhf.api.quickstart.model import ModelTrainEvalConfig, SGLangConfig, vLLMConfig
|
||||
from realhf.base import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -27,7 +29,22 @@ def check_valid_vllm(role: str, vllm: vLLMConfig, rpc_allocs: List[RPCAllocation
|
|||
if vllm.hybrid_train and not any(rpc.is_train() for rpc in rpcs):
|
||||
logger.warning("vLLM hybrid_train is enabled, but no training RPCs are found.")
|
||||
if vllm.hybrid_train and not vllm.enforce_eager:
|
||||
raise ValueError("vLLM hybrid_train requires eager mode to be enabled.")
|
||||
logger.warning(
|
||||
"For version < 0.7.0, vLLM hybrid_train requires eager mode to be enabled. "
|
||||
"The user has the responsibility to ensure the version is correct."
|
||||
)
|
||||
|
||||
|
||||
def check_valid_sglang(
|
||||
role: str, sglang: SGLangConfig, rpc_allocs: List[RPCAllocation]
|
||||
):
|
||||
rpcs = [alloc.rpc for alloc in rpc_allocs if alloc.rpc.role == role]
|
||||
if sglang.hybrid_train and not any(rpc.is_train() for rpc in rpcs):
|
||||
logger.warning(
|
||||
"SGLang hybrid_train is enabled, but no training RPCs are found."
|
||||
)
|
||||
if sglang.hybrid_train and not sglang.disable_cuda_graph:
|
||||
raise ValueError("SGLang hybrid_train requires CUDA graph to be disabled.")
|
||||
|
||||
|
||||
def check_valid_optimizer(model: ModelTrainEvalConfig):
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
|
||||
import dataclasses
|
||||
import itertools
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from typing import *
|
||||
|
||||
|
@ -51,6 +52,7 @@ from realhf.experiments.common.check import (
|
|||
check_valid_model_and_path,
|
||||
check_valid_optimizer,
|
||||
check_valid_parallel_batch_size,
|
||||
check_valid_sglang,
|
||||
check_valid_vllm,
|
||||
)
|
||||
from realhf.experiments.common.utils import (
|
||||
|
@ -69,7 +71,7 @@ import realhf.api.from_hf # isort:skip
|
|||
|
||||
logger = logging.getLogger("CommonExperimentConfig", "colored")
|
||||
|
||||
vLLM_HYBRID_TRAIN_DECOUPLE_ALLOC_WARN = False
|
||||
GEN_HYBRID_TRAIN_DECOUPLE_ALLOC_WARN = False
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
@ -108,7 +110,7 @@ class CommonExperimentConfig(Experiment):
|
|||
|
||||
- A regex pattern like ``d${DP}p${PP}m${TP}``\: Identical parallelization for all MFCs with ${DP}-way data parallelism, ${PP}-way pipeline parallelism, and ${TP}-way model parallelism.
|
||||
|
||||
- A regex pattern like ``vllm.{IdentPara}+{IdentPara}``\: Decoupled generation (vLLM) and training allocations with correspnding identical parallelization strategies. Note that the pipeline parallel degree of vLLM can only be 1.
|
||||
- A regex pattern like ``{vllm|sglang}.{IdentPara}+{IdentPara}``\: Decoupled generation and training allocations with correspnding identical parallelization strategies.
|
||||
|
||||
- Key-value pairs with MFC names and their parallel strategies in the whole cluster, e.g., ``actor_gen:d4m2p1,*:d2p2m2`` specifies a ``d4m2p1`` strategy for actor geneartion and ``d2p2m2`` for other MFCs in a world of 8 GPUs.
|
||||
|
||||
|
@ -363,6 +365,10 @@ class CommonExperimentConfig(Experiment):
|
|||
),
|
||||
)
|
||||
|
||||
@property
|
||||
def _allocation_mode(self):
|
||||
return AllocationMode.from_str(self.allocation_mode)
|
||||
|
||||
def _get_rpc_allocations(self) -> List[RPCAllocation]:
|
||||
if self.allocation_mode == "manual" and self.nodelist is None:
|
||||
logger.warning(
|
||||
|
@ -373,9 +379,7 @@ class CommonExperimentConfig(Experiment):
|
|||
f"and n_gpus_per_node {self.n_gpus_per_node}."
|
||||
)
|
||||
|
||||
self.__check_legal_allocation_options()
|
||||
|
||||
self._allocation_mode = AllocationMode.from_str(self.allocation_mode)
|
||||
self._check_legal_allocation_options()
|
||||
|
||||
rpcs = self.rpcs
|
||||
if self.allocation_mode == "search":
|
||||
|
@ -452,9 +456,9 @@ class CommonExperimentConfig(Experiment):
|
|||
raise ValueError(
|
||||
"The multiplication of 3D parallel degrees "
|
||||
"does not equal to the number of gpus. "
|
||||
"Note that the device mesh of vLLM should be disjoint from the device mesh of other MFCs, "
|
||||
"Note that the device mesh of vLLM/SGLang should be disjoint from the device mesh of other MFCs, "
|
||||
"so their summation should be equal to the total number of gpus. "
|
||||
f"dp={dp}, pp={pp}, mp={mp}, vllm.dp={gdp}, vllm.pp={gpp}, vllm.mp={gmp}, "
|
||||
f"dp={dp}, pp={pp}, mp={mp}, gen.dp={gdp}, gen.pp={gpp}, gen.mp={gmp}, "
|
||||
f"n_nodes={self.n_nodes}, n_gpus_per_node={self.n_gpus_per_node}"
|
||||
)
|
||||
alloc = RPCAllocation(
|
||||
|
@ -535,7 +539,7 @@ class CommonExperimentConfig(Experiment):
|
|||
def _get_model_worker_configs(
|
||||
self, rpc_allocs: List[RPCAllocation]
|
||||
) -> List[ModelWorker]:
|
||||
self.__run_model_sanity_check(rpc_allocs)
|
||||
self._run_model_sanity_check(rpc_allocs)
|
||||
|
||||
model_worker = []
|
||||
shard_counter = defaultdict(lambda: 0)
|
||||
|
@ -557,7 +561,7 @@ class CommonExperimentConfig(Experiment):
|
|||
tokenizer_name_or_path=self.tokenizer_name_or_path,
|
||||
)
|
||||
|
||||
# vLLM enabled model worker, shortcut case
|
||||
# decoupled allocation, shortcut case
|
||||
if (
|
||||
self._allocation_mode.is_decoupled()
|
||||
and self.gen_device_mesh.mapping[i, j]
|
||||
|
@ -574,22 +578,32 @@ class CommonExperimentConfig(Experiment):
|
|||
is_train=False,
|
||||
)
|
||||
model_cfg = self.models[model_name.role]
|
||||
global vLLM_HYBRID_TRAIN_DECOUPLE_ALLOC_WARN
|
||||
|
||||
gen_backend_name = ""
|
||||
if self._allocation_mode.is_decoupled_vllm():
|
||||
gen_backend_name = "vllm"
|
||||
elif self._allocation_mode.is_decoupled_sglang():
|
||||
gen_backend_name = "sglang"
|
||||
backend_cfg = getattr(model_cfg, gen_backend_name)
|
||||
|
||||
global GEN_HYBRID_TRAIN_DECOUPLE_ALLOC_WARN
|
||||
if (
|
||||
model_cfg.vllm.hybrid_train
|
||||
and not vLLM_HYBRID_TRAIN_DECOUPLE_ALLOC_WARN
|
||||
backend_cfg.hybrid_train
|
||||
and not GEN_HYBRID_TRAIN_DECOUPLE_ALLOC_WARN
|
||||
):
|
||||
logger.warning(
|
||||
"vLLM hybrid_train=True takes no effect for the decoupled allocation"
|
||||
"hybrid_train=True takes no effect for the decoupled allocation"
|
||||
)
|
||||
vLLM_HYBRID_TRAIN_DECOUPLE_ALLOC_WARN = True
|
||||
model_cfg.vllm.hybrid_train = False
|
||||
check_valid_vllm(model_name.role, model_cfg.vllm, rpc_allocs)
|
||||
GEN_HYBRID_TRAIN_DECOUPLE_ALLOC_WARN = True
|
||||
backend_cfg.hybrid_train = False
|
||||
|
||||
if gen_backend_name == "vllm":
|
||||
check_valid_vllm(model_name.role, model_cfg.vllm, rpc_allocs)
|
||||
elif gen_backend_name == "sglang":
|
||||
check_valid_sglang(model_name.role, model_cfg.sglang, rpc_allocs)
|
||||
|
||||
shard_idx = shard_counter[model_name]
|
||||
vllm_dict_args: Dict[str, Any] = OmegaConf.to_container(
|
||||
model_cfg.vllm, resolve=True
|
||||
)
|
||||
dict_args: Dict[str, Any] = asdict(backend_cfg)
|
||||
mw.shards.append(
|
||||
StandaloneModelShardAbstraction(
|
||||
id=ModelShardID(
|
||||
|
@ -603,11 +617,11 @@ class CommonExperimentConfig(Experiment):
|
|||
"tokenizer", args=dict(tokenizer_path=model_cfg.path)
|
||||
),
|
||||
backend=ModelBackendAbstraction(
|
||||
"vllm",
|
||||
gen_backend_name,
|
||||
args=dict(
|
||||
model_path=model_cfg.path,
|
||||
dtype="bfloat16" if model_cfg.bf16 else "float16",
|
||||
**vllm_dict_args,
|
||||
**dict_args,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
@ -684,29 +698,27 @@ class CommonExperimentConfig(Experiment):
|
|||
rpc.is_generate() for rpc in rpcs
|
||||
):
|
||||
assert len(rpcs) == 1 and rpcs[0].is_generate(), rpcs
|
||||
vllm_dict_args: Dict[str, Any] = asdict(model_cfg.vllm)
|
||||
assert (
|
||||
not model_cfg.sglang.hybrid_train
|
||||
), "vLLM and SGLang cannot be enabled at the same time"
|
||||
dict_args: Dict[str, Any] = asdict(model_cfg.vllm)
|
||||
check_valid_vllm(model_name.role, model_cfg.vllm, rpc_allocs)
|
||||
backend = ModelBackendAbstraction(
|
||||
"vllm",
|
||||
args=dict(
|
||||
model_path=model_cfg.path,
|
||||
**vllm_dict_args,
|
||||
**dict_args,
|
||||
),
|
||||
)
|
||||
elif model_cfg.sglang.hybrid_train and any(
|
||||
rpc.is_generate() for rpc in rpcs
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"SGLang hybrid_train=True is not supported yet."
|
||||
)
|
||||
else:
|
||||
backend = make_inf_backend_config(model_cfg, rpc_alloc.parallel)
|
||||
if any(rpc.is_generate() for rpc in rpcs) and backend.type_ not in [
|
||||
"vllm",
|
||||
"sglang",
|
||||
]:
|
||||
print(rpcs, model_name, backend.type_)
|
||||
raise ValueError(
|
||||
"vLLM or SGLang is not enabled for generation. "
|
||||
"This behavior has been deprecated. "
|
||||
"Please set model.vllm.hybrid_train=True "
|
||||
"or model.sglang.hybrid_train=True."
|
||||
)
|
||||
|
||||
check_valid_vllm(model_name.role, model_cfg.vllm, rpc_allocs)
|
||||
if mapping[i, j]:
|
||||
shard_idx = shard_counter[model_name]
|
||||
mw.shards.append(
|
||||
|
@ -749,7 +761,7 @@ class CommonExperimentConfig(Experiment):
|
|||
evaluator=self.auto_eval_config,
|
||||
)
|
||||
|
||||
def __check_legal_allocation_options(self):
|
||||
def _check_legal_allocation_options(self):
|
||||
if self.n_nodes > 1 and self.mode == "local":
|
||||
raise ValueError(
|
||||
"Cannot run multi-node experiment in local mode, "
|
||||
|
@ -791,7 +803,7 @@ class CommonExperimentConfig(Experiment):
|
|||
f"RPC {rpc.name} model name {rpc.model_name.role} is not in models."
|
||||
)
|
||||
|
||||
def __run_model_sanity_check(self, rpc_allocs: List[RPCAllocation]):
|
||||
def _run_model_sanity_check(self, rpc_allocs: List[RPCAllocation]):
|
||||
for alloc in rpc_allocs:
|
||||
check_valid_parallel_batch_size(alloc)
|
||||
for role, model in self.models.items():
|
||||
|
|
|
@ -90,6 +90,7 @@ class PPOHyperparameters:
|
|||
eps_clip: float = 0.2
|
||||
value_eps_clip: float = 0.2
|
||||
disable_value: bool = False
|
||||
recompute_logprob: bool = False
|
||||
max_reward_clip: float = 20.0
|
||||
reward_output_scaling: float = 1.0
|
||||
reward_output_bias: float = 0.0
|
||||
|
@ -189,15 +190,14 @@ class PPOMATHConfig(CommonExperimentConfig):
|
|||
default_factory=ModelTrainEvalConfig
|
||||
)
|
||||
ref: ModelTrainEvalConfig = dataclasses.field(default_factory=ModelTrainEvalConfig)
|
||||
rew: ModelTrainEvalConfig = dataclasses.field(default_factory=ModelTrainEvalConfig)
|
||||
|
||||
# for manual allocation only
|
||||
actor_train: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
||||
critic_train: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
||||
actor_gen: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
||||
critic_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
||||
rew_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
||||
ref_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
||||
actor_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
||||
|
||||
dataset: PromptOnlyDatasetConfig = dataclasses.field(
|
||||
default_factory=PromptOnlyDatasetConfig
|
||||
|
@ -248,30 +248,27 @@ class PPOMATHConfig(CommonExperimentConfig):
|
|||
@property
|
||||
def models(self) -> Dict[str, ModelTrainEvalConfig]:
|
||||
# role to config
|
||||
models = {
|
||||
"actor": self.actor,
|
||||
"critic": self.critic,
|
||||
"ref": self.ref,
|
||||
}
|
||||
if self.ppo.disable_value:
|
||||
return {
|
||||
"actor": self.actor,
|
||||
# "critic": self.critic,
|
||||
"ref": self.ref,
|
||||
# "reward": self.rew,
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"actor": self.actor,
|
||||
"critic": self.critic,
|
||||
"ref": self.ref,
|
||||
"reward": self.rew,
|
||||
}
|
||||
models.pop("critic")
|
||||
return models
|
||||
|
||||
@property
|
||||
def rpcs(self):
|
||||
if (
|
||||
self.dataset.max_prompt_len + self.ppo.gen.max_new_tokens
|
||||
(self._allocation_mode.is_decoupled_vllm() or self.actor.vllm.hybrid_train)
|
||||
and self.dataset.max_prompt_len + self.ppo.gen.max_new_tokens
|
||||
> self.actor.vllm.max_seq_len_to_capture
|
||||
and not self.actor.vllm.enforce_eager
|
||||
):
|
||||
raise RuntimeError(
|
||||
f"vllm max seq len to capture {self.actor.vllm.max_seq_len_to_capture} is "
|
||||
f"smaller than the prompt length + generation length {self.dataset.max_prompt_len + self.ppo.gen.max_new_tokens}"
|
||||
f"smaller than the prompt length + generation length "
|
||||
f"{self.dataset.max_prompt_len + self.ppo.gen.max_new_tokens}"
|
||||
)
|
||||
if not os.path.exists(os.getenv("REAL_MATH_METADATA_PATH")):
|
||||
raise RuntimeError(
|
||||
|
@ -334,6 +331,14 @@ class PPOMATHConfig(CommonExperimentConfig):
|
|||
+ 128,
|
||||
),
|
||||
)
|
||||
rollout_output_keys = [
|
||||
"seq_no_eos_mask",
|
||||
"packed_input_ids",
|
||||
"packed_logprobs",
|
||||
"prompt_mask",
|
||||
]
|
||||
if self.ppo.recompute_logprob:
|
||||
rollout_output_keys.remove("packed_logprobs")
|
||||
rollout = MFCDef(
|
||||
name="actor_gen",
|
||||
model_name="actor",
|
||||
|
@ -343,32 +348,27 @@ class PPOMATHConfig(CommonExperimentConfig):
|
|||
model_path=self.actor.path,
|
||||
interface_impl=actor_interface,
|
||||
input_keys=["packed_prompts"],
|
||||
output_keys=[
|
||||
"seq_no_eos_mask",
|
||||
"packed_input_ids",
|
||||
"packed_logprobs",
|
||||
"prompt_mask",
|
||||
],
|
||||
output_keys=rollout_output_keys,
|
||||
n_seqs=self.dataset.train_bs_n_seqs,
|
||||
)
|
||||
|
||||
inf_reward = MFCDef(
|
||||
name="rew_inf",
|
||||
model_name="reward",
|
||||
mb_spec=self.rew_inf.mb_spec,
|
||||
actor_inf = MFCDef(
|
||||
name="actor_inf",
|
||||
model_name="actor",
|
||||
mb_spec=self.actor_inf.mb_spec,
|
||||
interface_type=ModelInterfaceType.INFERENCE,
|
||||
interface_impl=rw_interface,
|
||||
model_type=self.rew.type,
|
||||
model_path=self.rew.path,
|
||||
min_n_seqs_per_pass=1 / self.group_size,
|
||||
input_keys=["packed_input_ids", "packed_prompts"],
|
||||
output_keys=["rewards", "dense_rewards"],
|
||||
model_type=self.actor.type,
|
||||
model_path=self.actor.path,
|
||||
interface_impl=actor_interface,
|
||||
input_keys=["packed_input_ids"],
|
||||
output_keys=["packed_logprobs"],
|
||||
output_key_remap=dict(logprobs="packed_logprobs"),
|
||||
n_seqs=self.dataset.train_bs_n_seqs,
|
||||
)
|
||||
|
||||
# add rew param into ref MFC
|
||||
inf_ref_inputs = ["packed_input_ids", "packed_prompts"]
|
||||
inf_ref_outputs = ["packed_ref_logprobs", "rewards", "dense_rewards"]
|
||||
inf_ref_outputs = ["logprobs", "rewards", "dense_rewards"]
|
||||
ref_interface = copy.deepcopy(actor_interface)
|
||||
ref_interface.type_ = "ref_rw"
|
||||
ref_interface.args["enable_save"] = False
|
||||
|
@ -385,6 +385,7 @@ class PPOMATHConfig(CommonExperimentConfig):
|
|||
min_n_seqs_per_pass=1 / self.group_size,
|
||||
input_keys=inf_ref_inputs,
|
||||
output_keys=inf_ref_outputs,
|
||||
output_key_remap=dict(logprobs="packed_ref_logprobs"),
|
||||
n_seqs=self.dataset.train_bs_n_seqs,
|
||||
)
|
||||
|
||||
|
@ -450,47 +451,40 @@ class PPOMATHConfig(CommonExperimentConfig):
|
|||
min_n_seqs_per_pass=self.ppo.ppo_n_minibatches / self.group_size,
|
||||
n_seqs=self.dataset.train_bs_n_seqs,
|
||||
)
|
||||
|
||||
rpcs = {
|
||||
"actor_gen": rollout,
|
||||
"actor_train": train_actor,
|
||||
"critic_inf": inf_values,
|
||||
"critic_train": train_critic,
|
||||
"ref_inf": inf_ref_logits,
|
||||
"actor_inf": actor_inf,
|
||||
"rew_inf": inf_reward,
|
||||
}
|
||||
if self.ppo.disable_value:
|
||||
return {
|
||||
"actor_gen": rollout,
|
||||
"actor_train": train_actor,
|
||||
# "critic_inf": inf_values,
|
||||
# "critic_train": train_critic,
|
||||
# "ref_inf": inf_ref_logits,
|
||||
# "rew_inf": inf_reward,
|
||||
"ref_rw": inf_ref_logits,
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"actor_gen": rollout,
|
||||
"actor_train": train_actor,
|
||||
"critic_inf": inf_values,
|
||||
"critic_train": train_critic,
|
||||
"ref_inf": inf_ref_logits,
|
||||
"rew_inf": inf_reward,
|
||||
}
|
||||
rpcs.pop("critic_inf")
|
||||
rpcs.pop("critic_train")
|
||||
if not self.ppo.recompute_logprob:
|
||||
rpcs.pop("actor_inf")
|
||||
return rpcs
|
||||
|
||||
@property
|
||||
def allocations(self):
|
||||
allocs = {
|
||||
"actor_gen": self.actor_gen,
|
||||
"actor_train": self.actor_train,
|
||||
"critic_inf": self.critic_inf,
|
||||
"critic_train": self.critic_train,
|
||||
"ref_inf": self.ref_inf,
|
||||
"actor_inf": self.actor_inf,
|
||||
"rew_inf": self.rew_inf,
|
||||
}
|
||||
if self.ppo.disable_value:
|
||||
return {
|
||||
"actor_gen": self.actor_gen,
|
||||
"actor_train": self.actor_train,
|
||||
# "critic_inf": self.critic_inf,
|
||||
# "critic_train": self.critic_train,
|
||||
# "ref_inf": self.ref_inf,
|
||||
# "rew_inf": self.rew_inf,
|
||||
"ref_rw": self.ref_inf,
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"actor_gen": self.actor_gen,
|
||||
"actor_train": self.actor_train,
|
||||
"critic_inf": self.critic_inf,
|
||||
"critic_train": self.critic_train,
|
||||
"ref_inf": self.ref_inf,
|
||||
"rew_inf": self.rew_inf,
|
||||
}
|
||||
allocs.pop("critic_inf")
|
||||
allocs.pop("critic_train")
|
||||
if not self.ppo.recompute_logprob:
|
||||
allocs.pop("actor_inf")
|
||||
return allocs
|
||||
|
||||
@property
|
||||
def datasets(self):
|
||||
|
|
|
@ -110,32 +110,40 @@ def make_inf_backend_config(
|
|||
def resolve_replica_ids(
|
||||
rpc_allocs: List[RPCAllocation], models: Dict[str, ModelTrainEvalConfig]
|
||||
):
|
||||
role_cnt = collections.defaultdict(int)
|
||||
first_device_mesh = dict()
|
||||
first_parallel = dict()
|
||||
first_rpc = dict()
|
||||
role_rpcs = collections.defaultdict(list)
|
||||
for alloc in rpc_allocs:
|
||||
rpc = alloc.rpc
|
||||
if rpc.role not in first_device_mesh:
|
||||
first_device_mesh[rpc.role] = alloc.device_mesh
|
||||
first_parallel[rpc.role] = alloc.parallel
|
||||
first_rpc[rpc.role] = rpc
|
||||
role_rpcs[rpc.role].append(alloc)
|
||||
|
||||
for role, allocs in role_rpcs.items():
|
||||
cnt = len(allocs)
|
||||
if cnt == 1:
|
||||
allocs[0].rpc.model_name = ModelName(role, 0)
|
||||
continue
|
||||
model_cfg = models[rpc.role]
|
||||
if (rpc.is_train() and first_rpc[rpc.role].is_generate()) or (
|
||||
rpc.is_generate() and first_rpc[rpc.role].is_train()
|
||||
):
|
||||
if model_cfg.vllm.hybrid_train:
|
||||
role_cnt[rpc.role] += 1
|
||||
rpc.model_name = ModelName(rpc.role, role_cnt[rpc.role])
|
||||
rpcs = [alloc.rpc for alloc in allocs]
|
||||
if any(rpc.is_train() for rpc in rpcs):
|
||||
main_alloc = next(alloc for alloc in allocs if alloc.rpc.is_train())
|
||||
elif any(rpc.is_inference() for rpc in rpcs):
|
||||
main_alloc = next(alloc for alloc in allocs if alloc.rpc.is_inference())
|
||||
else:
|
||||
main_alloc = allocs[0]
|
||||
main_alloc.rpc.model_name = ModelName(role, 0)
|
||||
i = 1
|
||||
for alloc in allocs:
|
||||
if alloc.rpc.name == main_alloc.rpc.name:
|
||||
continue
|
||||
if alloc.device_mesh != first_device_mesh[rpc.role] or not parallelism_eq(
|
||||
alloc.parallel, first_parallel[rpc.role]
|
||||
):
|
||||
role_cnt[rpc.role] += 1
|
||||
rpc.model_name = ModelName(rpc.role, role_cnt[rpc.role])
|
||||
continue
|
||||
assert rpc.model_name.replica_id == 0
|
||||
same_alloc = alloc.device_mesh == main_alloc.device_mesh and parallelism_eq(
|
||||
alloc.parallel, main_alloc.parallel
|
||||
)
|
||||
if not same_alloc or (
|
||||
alloc.rpc.is_generate()
|
||||
and main_alloc.rpc.is_train()
|
||||
and (models[role].vllm.hybrid_train or models[role].sglang.hybrid_train)
|
||||
):
|
||||
alloc.rpc.model_name = ModelName(role, i)
|
||||
i += 1
|
||||
else:
|
||||
alloc.rpc.model_name = ModelName(role, 0)
|
||||
|
||||
|
||||
def resolve_rpc_hooks(
|
||||
|
@ -207,6 +215,7 @@ class AllocationType(enum.Enum):
|
|||
HEURISTIC = 4
|
||||
SEARCH = 5
|
||||
DECOUPLED_SGLANG = 6
|
||||
DECOUPLED_MOCK = 7
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
@ -218,6 +227,7 @@ class AllocationMode:
|
|||
return self.type_ in [
|
||||
AllocationType.DECOUPLED_vLLM,
|
||||
AllocationType.DECOUPLED_SGLANG,
|
||||
AllocationType.DECOUPLED_MOCK,
|
||||
]
|
||||
|
||||
def is_decoupled_vllm(self):
|
||||
|
@ -226,6 +236,9 @@ class AllocationMode:
|
|||
def is_decoupled_sglang(self):
|
||||
return self.type_ == AllocationType.DECOUPLED_SGLANG
|
||||
|
||||
def is_decoupled_mock(self):
|
||||
return self.type_ == AllocationType.DECOUPLED_MOCK
|
||||
|
||||
def is_global_hybrid(self):
|
||||
return self.type_ == AllocationType.GLOBAL_HYBRID
|
||||
|
||||
|
@ -246,6 +259,8 @@ class AllocationMode:
|
|||
return cls(AllocationType.DECOUPLED_vLLM, alloc_decoupled)
|
||||
elif "sglang" in allocation_mode:
|
||||
return cls(AllocationType.DECOUPLED_SGLANG, alloc_decoupled)
|
||||
elif "mock" in allocation_mode:
|
||||
return cls(AllocationType.DECOUPLED_MOCK, alloc_decoupled)
|
||||
if alloc_3d:
|
||||
return cls(AllocationType.GLOBAL_HYBRID, alloc_3d)
|
||||
if alloc_hybrid:
|
||||
|
@ -272,7 +287,7 @@ class AllocationMode:
|
|||
@staticmethod
|
||||
def extract_decoupled_alloc(allocation_mode: str) -> Dict | None:
|
||||
pattern = re.compile(
|
||||
r"(?:(?:vllm|sglang)\.(.+?)\+(.+))|(?:(.+?)\+(?:vllm|sglang)\.(.+))"
|
||||
r"(?:(?:vllm|sglang|mock)\.(.+?)\+(.+))|(?:(.+?)\+(?:vllm|sglang|mock)\.(.+))"
|
||||
)
|
||||
m = pattern.match(allocation_mode)
|
||||
if not m:
|
||||
|
|
|
@ -14,21 +14,50 @@ from typing import *
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
import transformers
|
||||
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
||||
|
||||
from realhf.api.core import model_api
|
||||
from realhf.api.core.data_api import MicroBatchSpec, SequenceSample
|
||||
from realhf.api.quickstart.model import (
|
||||
DistributedDataParallelConfig,
|
||||
MegatronConfig,
|
||||
OptimizerConfig,
|
||||
)
|
||||
from realhf.base import constants, logging
|
||||
from realhf.base.datapack import flat2d
|
||||
from realhf.base.monitor import CUDATimeMarkType, cuda_tmarked
|
||||
from realhf.impl.model.backend.inference import PipelinableInferenceEngine
|
||||
from realhf.impl.model.backend.pipe_runner import PipelineRunner, PipeTrainInstrSet
|
||||
from realhf.impl.model.modules.mlp import get_activation_fn
|
||||
from realhf.impl.model.nn.flatten_param import ContiguousParamSpec
|
||||
from realhf.impl.model.nn.real_llm_api import ReaLModel
|
||||
from realhf.impl.model.nn.real_llm_base import ReaLModelBlock
|
||||
from realhf.impl.model.parallelism.pipeline_parallel.tensor_storage import TensorBuffer
|
||||
|
||||
try:
|
||||
# Monkey patch
|
||||
import megatron.core.optimizer as mcore_optim
|
||||
|
||||
class DistributedOptimizer(mcore_optim.DistributedOptimizer):
|
||||
def get_model_parallel_group(self):
|
||||
return constants.parallelism_group()
|
||||
|
||||
mcore_optim.DistributedOptimizer = DistributedOptimizer
|
||||
|
||||
from megatron.core import parallel_state
|
||||
from megatron.core.distributed.distributed_data_parallel import (
|
||||
DistributedDataParallel,
|
||||
)
|
||||
from megatron.core.distributed.param_and_grad_buffer import ParamAndGradBuffer
|
||||
from megatron.core.optimizer import DistributedOptimizer, get_megatron_optimizer
|
||||
from megatron.core.optimizer.clip_grads import clip_grad_norm_fp32, count_zeros_fp32
|
||||
from megatron.core.optimizer.optimizer_config import (
|
||||
OptimizerConfig as MegatronOptimizerConfig,
|
||||
)
|
||||
from megatron.core.transformer.transformer_config import (
|
||||
TransformerConfig as MegatronTransformerConfig,
|
||||
)
|
||||
|
||||
megatron_available = True
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
# importing megatron.core in CPU container will fail due to the requirement of apex
|
||||
# Here class types must be defined for type hinting
|
||||
|
@ -42,19 +71,19 @@ except (ModuleNotFoundError, ImportError):
|
|||
pass
|
||||
|
||||
|
||||
from realhf.api.core import model_api
|
||||
from realhf.api.core.data_api import MicroBatchSpec, SequenceSample
|
||||
from realhf.api.quickstart.model import MegatronConfig, OptimizerConfig
|
||||
from realhf.base import constants, logging
|
||||
from realhf.base.datapack import flat2d
|
||||
from realhf.base.monitor import CUDATimeMarkType, cuda_tmarked
|
||||
from realhf.impl.model.backend.inference import PipelinableInferenceEngine
|
||||
from realhf.impl.model.backend.pipe_runner import PipelineRunner, PipeTrainInstrSet
|
||||
from realhf.impl.model.modules.mlp import get_activation_fn
|
||||
from realhf.impl.model.nn.flatten_param import ContiguousParamSpec
|
||||
from realhf.impl.model.nn.real_llm_api import ReaLModel
|
||||
from realhf.impl.model.nn.real_llm_base import ReaLModelBlock
|
||||
from realhf.impl.model.parallelism.pipeline_parallel.tensor_storage import TensorBuffer
|
||||
if megatron_available:
|
||||
try:
|
||||
from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler
|
||||
|
||||
use_old_megatron = False
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
# The above object is available in 0.9.0 but missing in 0.6.0
|
||||
from realhf.impl.model.backend.thirdparty.megatron.v0_6_0.lr_schduler import (
|
||||
OptimizerParamScheduler,
|
||||
)
|
||||
|
||||
use_old_megatron = True
|
||||
|
||||
|
||||
WITHIN_MEGATRON_CONTEXT = False
|
||||
|
||||
|
@ -122,6 +151,7 @@ def megatron_ctx():
|
|||
parallel_state._TENSOR_AND_EXPERT_PARALLEL_GROUP = constants.model_parallel_group()
|
||||
g = constants.data_parallel_group()
|
||||
parallel_state._DATA_MODULO_EXPERT_PARALLEL_GROUP = g
|
||||
parallel_state._DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP = g
|
||||
parallel_state._DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO = (
|
||||
grid.get_data_parallel_group_gloo()
|
||||
)
|
||||
|
@ -166,462 +196,62 @@ class MegatronEngine:
|
|||
self.ddp.zero_grad_buffer()
|
||||
self.optim.zero_grad(set_to_none=set_to_none)
|
||||
|
||||
def _all_reduce_layernorm_grads(self):
|
||||
if not (
|
||||
constants.sequence_parallel() and constants.model_parallel_world_size() > 1
|
||||
):
|
||||
return
|
||||
real_model: ReaLModel = self.ddp.module
|
||||
grads = []
|
||||
for i in range(real_model.layer_idx_start, real_model.layer_idx_end):
|
||||
if i == 0:
|
||||
continue
|
||||
elif i == real_model.config.n_layers + 1:
|
||||
continue
|
||||
else:
|
||||
assert 0 < i < real_model.config.n_layers + 1
|
||||
layer: ReaLModelBlock = real_model.layers[
|
||||
i - real_model.layer_idx_start
|
||||
]
|
||||
grads.append(layer.attn.c_attn.ln.weight.main_grad)
|
||||
if getattr(layer.attn.c_attn.ln, "bias", None) is not None:
|
||||
grads.append(layer.attn.c_attn.ln.bias.main_grad)
|
||||
grads.append(layer.mlp.ln.weight.main_grad)
|
||||
if getattr(layer.mlp.ln, "bias", None) is not None:
|
||||
grads.append(layer.mlp.ln.bias.main_grad)
|
||||
if i == real_model.config.n_layers:
|
||||
grads.append(layer.ln_f.weight.main_grad)
|
||||
if getattr(layer.ln_f, "bias", None) is not None:
|
||||
grads.append(layer.ln_f.bias.main_grad)
|
||||
|
||||
assert all(x is not None for x in grads)
coalesced = _flatten_dense_tensors(grads)
dist.all_reduce(coalesced, group=constants.model_parallel_group())
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)

def _all_reduce_word_embedding_grads(self):
real_model: ReaLModel = self.ddp.module
if not real_model.config.tied_embedding or real_model.config.is_critic:
return
pp_size = constants.pipe_parallel_world_size()
pp_rank = constants.pipe_parallel_rank()
if pp_size == 1:
return
if pp_rank not in [0, pp_size - 1]:
return

if pp_rank == 0:
grad = real_model.layers[0].wte.weight.main_grad
else:
grad = real_model.layers[-1].weight.main_grad

dist.all_reduce(grad, group=constants.grid().embedding_proc_group)

# Adopted from Megatron-LM/megatron/training/optimizer_param_scheduler.py
class OptimizerParamScheduler(object):
"""Anneals learning rate and weight decay.

Adopted from Megatron-LM. This class is not included in
megatron.core, so we have to copy-paste it here.
"""
||||
|
||||
def __init__(
|
||||
self,
|
||||
optimizer,
|
||||
init_lr,
|
||||
max_lr,
|
||||
min_lr,
|
||||
lr_warmup_steps,
|
||||
lr_decay_steps,
|
||||
lr_decay_style,
|
||||
start_wd,
|
||||
end_wd,
|
||||
wd_incr_steps,
|
||||
wd_incr_style,
|
||||
use_checkpoint_opt_param_scheduler=True,
|
||||
override_opt_param_scheduler=False,
|
||||
):
|
||||
|
||||
# Class values.
|
||||
self.optimizer = optimizer
|
||||
|
||||
self.init_lr = init_lr
|
||||
self.max_lr = float(max_lr)
|
||||
self.min_lr = min_lr
|
||||
assert self.min_lr >= 0.0
|
||||
assert self.max_lr >= self.min_lr
|
||||
assert self.init_lr <= self.max_lr
|
||||
|
||||
self.lr_warmup_steps = lr_warmup_steps
|
||||
self.num_steps = 0
|
||||
self.lr_decay_steps = lr_decay_steps
|
||||
assert self.lr_decay_steps > 0
|
||||
assert self.lr_warmup_steps < self.lr_decay_steps
|
||||
|
||||
self.lr_decay_style = lr_decay_style
|
||||
|
||||
self.start_wd = start_wd
|
||||
self.end_wd = end_wd
|
||||
assert self.start_wd >= 0.0
|
||||
assert self.end_wd >= self.start_wd
|
||||
self.wd_incr_steps = wd_incr_steps
|
||||
self.wd_incr_style = wd_incr_style
|
||||
|
||||
self.override_opt_param_scheduler = override_opt_param_scheduler
|
||||
self.use_checkpoint_opt_param_scheduler = use_checkpoint_opt_param_scheduler
|
||||
if self.override_opt_param_scheduler:
|
||||
assert not self.use_checkpoint_opt_param_scheduler, (
|
||||
"both override and " "use-checkpoint are set."
|
||||
)
|
||||
|
||||
# Set the learning rate
|
||||
self.step(0)
|
||||
self.log_rank_0("> learning rate decay style: {}".format(self.lr_decay_style))
|
||||
|
||||
def log_rank_0(self, msg):
|
||||
if constants.parallelism_rank() == 0:
|
||||
logger.info(msg)
|
||||
|
||||
def get_wd(self):
|
||||
"""Weight decay incr functions."""
|
||||
if self.num_steps > self.wd_incr_steps:
|
||||
return self.end_wd
|
||||
|
||||
if self.wd_incr_style == "constant":
|
||||
assert self.start_wd == self.end_wd
|
||||
return self.end_wd
|
||||
|
||||
incr_ratio = float(self.num_steps) / float(self.wd_incr_steps)
|
||||
assert incr_ratio >= 0.0
|
||||
assert incr_ratio <= 1.0
|
||||
delta_wd = self.end_wd - self.start_wd
|
||||
|
||||
if self.wd_incr_style == "linear":
|
||||
coeff = incr_ratio
|
||||
elif self.wd_incr_style == "cosine":
|
||||
coeff = 0.5 * (math.cos(math.pi * (1 - incr_ratio)) + 1.0)
|
||||
else:
raise Exception(
"{} weight decay increment style is not supported.".format(
self.wd_incr_style
)
)

return self.start_wd + coeff * delta_wd
|
||||
|
||||
def get_lr(self, param_group):
|
||||
"""Learning rate decay functions from:
|
||||
https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""
|
||||
|
||||
max_lr = param_group.get("max_lr", self.max_lr)
|
||||
min_lr = param_group.get("min_lr", self.min_lr)
|
||||
|
||||
# Use linear warmup for the initial part.
|
||||
if self.lr_warmup_steps > 0 and self.num_steps <= self.lr_warmup_steps:
|
||||
return self.init_lr + (
|
||||
(max_lr - self.init_lr)
|
||||
* float(self.num_steps)
|
||||
/ float(self.lr_warmup_steps)
|
||||
)
|
||||
|
||||
# If the learning rate is constant, just return the initial value.
|
||||
if self.lr_decay_style == "constant":
|
||||
return max_lr
|
||||
|
||||
# For any steps larger than `self.lr_decay_steps`, use `min_lr`.
|
||||
if self.num_steps > self.lr_decay_steps:
|
||||
return min_lr
|
||||
|
||||
# If we are done with the warmup period, use the decay style.
|
||||
if self.lr_decay_style == "inverse-square-root":
|
||||
warmup_steps = max(self.lr_warmup_steps, 1)
|
||||
num_steps = max(self.num_steps, 1)
|
||||
lr = max_lr * warmup_steps**0.5 / (num_steps**0.5)
|
||||
return max(min_lr, lr)
|
||||
|
||||
num_steps_ = self.num_steps - self.lr_warmup_steps
|
||||
decay_steps_ = self.lr_decay_steps - self.lr_warmup_steps
|
||||
decay_ratio = float(num_steps_) / float(decay_steps_)
|
||||
assert decay_ratio >= 0.0
|
||||
assert decay_ratio <= 1.0
|
||||
delta_lr = max_lr - min_lr
|
||||
|
||||
if self.lr_decay_style == "linear":
|
||||
coeff = 1.0 - decay_ratio
|
||||
elif self.lr_decay_style == "cosine":
|
||||
coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
|
||||
else:
|
||||
raise Exception(
|
||||
"{} decay style is not supported.".format(self.lr_decay_style)
|
||||
)
|
||||
|
||||
return min_lr + coeff * delta_lr
|
||||
|
||||
def step_absolute(self, num_steps):
|
||||
"""Set lr for all parameters groups."""
|
||||
if num_steps is None:
|
||||
self.num_steps += 1
|
||||
else:
|
||||
self.num_steps = num_steps
|
||||
new_wd = self.get_wd()
|
||||
for param_group in self.optimizer.param_groups:
|
||||
new_lr = self.get_lr(param_group)
|
||||
param_group["lr"] = new_lr * param_group.get("lr_mult", 1.0)
|
||||
param_group["weight_decay"] = new_wd * param_group.get("wd_mult", 1.0)
|
||||
|
||||
def step(self, increment):
|
||||
"""Set lr for all parameters groups."""
|
||||
self.num_steps += increment
|
||||
new_wd = self.get_wd()
|
||||
for param_group in self.optimizer.param_groups:
|
||||
new_lr = self.get_lr(param_group)
|
||||
param_group["lr"] = new_lr * param_group.get("lr_mult", 1.0)
|
||||
param_group["weight_decay"] = new_wd * param_group.get("wd_mult", 1.0)
|
||||
|
||||
def state_dict(self):
|
||||
state_dict = {
|
||||
"max_lr": self.max_lr,
|
||||
"lr_warmup_steps": self.lr_warmup_steps,
|
||||
"num_steps": self.num_steps,
|
||||
"lr_decay_style": self.lr_decay_style,
|
||||
"lr_decay_steps": self.lr_decay_steps,
|
||||
"min_lr": self.min_lr,
|
||||
"start_wd": self.start_wd,
|
||||
"end_wd": self.end_wd,
|
||||
"wd_incr_style": self.wd_incr_style,
|
||||
"wd_incr_steps": self.wd_incr_steps,
|
||||
}
|
||||
return state_dict
|
||||
|
||||
def _check_and_set(self, cls_value, sd_value, name):
|
||||
"""Auxiliary function for checking the values in the checkpoint and
|
||||
setting them."""
|
||||
if self.override_opt_param_scheduler:
|
||||
self.log_rank_0(" > overriding {} value to {}".format(name, cls_value))
|
||||
return cls_value
|
||||
|
||||
if not self.use_checkpoint_opt_param_scheduler:
|
||||
assert cls_value == sd_value, (
|
||||
f"OptimizerParamScheduler: class input value {cls_value} and checkpoint"
|
||||
f"value {sd_value} for {name} do not match"
|
||||
)
|
||||
self.log_rank_0(" > using checkpoint value {} for {}".format(sd_value, name))
|
||||
return sd_value
|
||||
|
||||
def load_state_dict(self, sd):
|
||||
|
||||
if "start_lr" in sd:
|
||||
max_lr_ = sd["start_lr"]
|
||||
else:
|
||||
max_lr_ = sd["max_lr"]
|
||||
self.max_lr = self._check_and_set(self.max_lr, max_lr_, "learning rate")
|
||||
|
||||
self.min_lr = self._check_and_set(
|
||||
self.min_lr, sd["min_lr"], "minimum learning rate"
|
||||
)
|
||||
|
||||
if "warmup_iter" in sd:
|
||||
lr_warmup_steps_ = sd["warmup_iter"]
|
||||
elif "warmup_steps" in sd:
|
||||
lr_warmup_steps_ = sd["warmup_steps"]
|
||||
else:
|
||||
lr_warmup_steps_ = sd["lr_warmup_steps"]
|
||||
self.lr_warmup_steps = self._check_and_set(
|
||||
self.lr_warmup_steps, lr_warmup_steps_, "warmup iterations"
|
||||
)
|
||||
|
||||
if "end_iter" in sd:
|
||||
lr_decay_steps_ = sd["end_iter"]
|
||||
elif "decay_steps" in sd:
|
||||
lr_decay_steps_ = sd["decay_steps"]
|
||||
else:
|
||||
lr_decay_steps_ = sd["lr_decay_steps"]
|
||||
self.lr_decay_steps = self._check_and_set(
|
||||
self.lr_decay_steps, lr_decay_steps_, "total number of iterations"
|
||||
)
|
||||
|
||||
if "decay_style" in sd:
|
||||
lr_decay_style_ = sd["decay_style"]
|
||||
else:
|
||||
lr_decay_style_ = sd["lr_decay_style"]
|
||||
self.lr_decay_style = self._check_and_set(
|
||||
self.lr_decay_style, lr_decay_style_, "learning rate decay style"
|
||||
)
|
||||
|
||||
if "num_iters" in sd:
|
||||
num_steps = sd["num_iters"]
|
||||
else:
|
||||
num_steps = sd["num_steps"]
|
||||
self.step(increment=num_steps)
|
||||
|
||||
if "start_wd" in sd:
|
||||
self.start_wd = self._check_and_set(
|
||||
self.start_wd, sd["start_wd"], "start weight decay"
|
||||
)
|
||||
self.end_wd = self._check_and_set(
|
||||
self.end_wd, sd["end_wd"], "end weight decay"
|
||||
)
|
||||
self.wd_incr_steps = self._check_and_set(
|
||||
self.wd_incr_steps,
|
||||
sd["wd_incr_steps"],
|
||||
"total number of weight decay iterations",
|
||||
)
|
||||
self.wd_incr_style = self._check_and_set(
|
||||
self.wd_incr_style,
|
||||
sd["wd_incr_style"],
|
||||
"weight decay incr style",
|
||||
)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def _step_megatron_distrib_optimizer_internal(optim: DistributedOptimizer):
|
||||
# NOTE: patching this function to use the correct model parallel group
|
||||
|
||||
optim._copy_model_grads_to_main_grads()
|
||||
|
||||
# Do unscale, check for inf, and update grad scaler only for
|
||||
# the case that grad scaler is provided.
|
||||
if optim.grad_scaler:
|
||||
|
||||
def _unscale_main_grads_and_check_for_nan(optim: DistributedOptimizer):
|
||||
|
||||
# Collect main grads.
|
||||
main_grads = optim._collect_main_grad_data_for_unscaling()
|
||||
|
||||
# Reset found inf.
|
||||
optim.found_inf.fill_(0.0)
|
||||
|
||||
# Unscale and set found inf/nan
|
||||
torch._amp_foreach_non_finite_check_and_unscale_(
|
||||
main_grads, optim.found_inf, optim.grad_scaler.inv_scale
|
||||
)
|
||||
|
||||
# Update across all model parallel instances.
|
||||
dist.all_reduce(
|
||||
optim.found_inf,
|
||||
op=dist.ReduceOp.MAX,
|
||||
group=constants.parallelism_group(),
|
||||
)
|
||||
|
||||
# Check for nan.
|
||||
found_inf_flag = optim.found_inf.item() > 0
|
||||
|
||||
return found_inf_flag
|
||||
|
||||
# Unscale and check for inf/nan.
|
||||
found_inf_flag = _unscale_main_grads_and_check_for_nan(optim)
|
||||
|
||||
# We are done with scaling gradients
|
||||
# so we can update the loss scale.
|
||||
optim.grad_scaler.update(found_inf_flag)
|
||||
|
||||
# If we found inf/nan, skip the update.
|
||||
if found_inf_flag:
|
||||
optim.update_successful, grad_norm, num_zeros_in_grad = (
|
||||
False,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
return optim.update_successful, grad_norm, num_zeros_in_grad
|
||||
|
||||
def clip_grad_norm(optim: DistributedOptimizer, clip_grad: float) -> float:
|
||||
"""Compute grad norm."""
|
||||
params = optim.get_parameters()
|
||||
grads_for_norm = optim.get_main_grads_for_grad_norm()
|
||||
return clip_grad_norm_fp32(
|
||||
params,
|
||||
grads_for_norm,
|
||||
clip_grad,
|
||||
model_parallel_group=constants.parallelism_group(),
|
||||
)
|
||||
|
||||
def count_zeros(optim: DistributedOptimizer) -> float:
|
||||
"""Count number of zeros in model's gradients."""
|
||||
params = optim.get_parameters()
|
||||
return count_zeros_fp32(
|
||||
params,
|
||||
model_parallel_group=constants.parallelism_group(),
|
||||
)
|
||||
|
||||
# Clip the main gradients.
|
||||
grad_norm = None
|
||||
if optim.config.clip_grad > 0.0:
|
||||
grad_norm = clip_grad_norm(optim, optim.config.clip_grad)
|
||||
|
||||
# Count the zeros in the grads.
|
||||
num_zeros_in_grad = None
|
||||
if optim.config.log_num_zeros_in_grad:
|
||||
num_zeros_in_grad = count_zeros(optim)
|
||||
|
||||
# Step the optimizer.
|
||||
optim.optimizer.step()
|
||||
|
||||
# Update params from main params.
|
||||
optim._copy_main_params_to_model_params()
|
||||
# Successful update.
|
||||
optim.update_successful, grad_norm, num_zeros_in_grad = (
|
||||
True,
|
||||
grad_norm,
|
||||
num_zeros_in_grad,
|
||||
)
|
||||
return optim.update_successful, grad_norm, num_zeros_in_grad
|
||||
|
||||
|
||||
def step_megatron_distrb_optimizer(optim: DistributedOptimizer):
|
||||
|
||||
optim.update_successful, grad_norm, num_zeros_in_grad = (
|
||||
_step_megatron_distrib_optimizer_internal(optim)
|
||||
)
|
||||
|
||||
# If not overlapping all-gather for parameters, launch synchronous all-gather
|
||||
# communication calls here. If overlapping all-gather for parameters, the following
|
||||
# call to _gather_all_model_params is a no-op: the first all-gather is launched
|
||||
# asynchronously in the next optimizer.zero_grad() call and subsequent all-gathers
|
||||
# are launched in the forward pre-hook.
|
||||
optim._reset_metadata_and_sync_gather_all_model_params(force_sync=False)
|
||||
|
||||
return optim.update_successful, grad_norm, num_zeros_in_grad
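# --- Illustrative sketch, not part of the patch above. ---
# The unscale-and-check step in _step_megatron_distrib_optimizer_internal, restated in
# plain torch: gradients are multiplied by the inverse loss scale, and the optimizer
# step is skipped when any non-finite value appears. The tensors below are invented.
import torch

def unscale_and_check(grads, inv_scale):
    found_inf = False
    for g in grads:
        g.mul_(inv_scale)
        found_inf = found_inf or (not torch.isfinite(g).all().item())
    return found_inf

example_grads = [torch.tensor([2.0, 4.0]), torch.tensor([float("inf")])]
print(unscale_and_check(example_grads, 0.5))  # True -> this update would be skipped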
|
||||
|
||||
|
||||
def _flatten_dense_tensors(tensors):
|
||||
"""Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of
|
||||
same dense type.
|
||||
|
||||
Since inputs are dense, the resulting tensor will be a concatenated 1D
|
||||
buffer. Element-wise operation on this buffer will be equivalent to
|
||||
operating individually.
|
||||
|
||||
Args:
|
||||
tensors (Iterable[Tensor]): dense tensors to flatten.
|
||||
|
||||
Returns:
|
||||
A contiguous 1D buffer containing input tensors.
|
||||
"""
|
||||
return torch._C._nn.flatten_dense_tensors(tensors)
|
||||
|
||||
|
||||
def _unflatten_dense_tensors(flat, tensors):
|
||||
"""View a flat buffer using the sizes of tensors. Assume that tensors are
|
||||
of same dense type, and that flat is given by _flatten_dense_tensors.
|
||||
|
||||
Args:
|
||||
flat (Tensor): flattened dense tensors to unflatten.
|
||||
tensors (Iterable[Tensor]): dense tensors whose sizes will be used to
|
||||
unflatten flat.
|
||||
|
||||
Returns:
|
||||
Unflattened dense tensors with sizes same as tensors and values from
|
||||
flat.
|
||||
"""
|
||||
return torch._C._nn.unflatten_dense_tensors(flat, tensors)
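# --- Illustrative sketch, not part of the patch above. ---
# How the flatten/unflatten pair is typically used: coalesce many small tensors into one
# contiguous buffer, run a single operation on it (here an in-place scale stands in for
# dist.all_reduce), then copy the results back. Tensor contents are invented examples.
import torch

def coalesced_inplace_scale(tensors, factor):
    flat = torch._C._nn.flatten_dense_tensors(tensors)
    flat.mul_(factor)  # stand-in for a single collective over the coalesced buffer
    for t, synced in zip(tensors, torch._C._nn.unflatten_dense_tensors(flat, tensors)):
        t.copy_(synced)

bufs = [torch.ones(3), torch.ones(2, 2)]
coalesced_inplace_scale(bufs, 0.5)
print(bufs[0])  # tensor([0.5000, 0.5000, 0.5000])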
|
||||
|
||||
|
||||
def megatron_all_reduce_layernorm_grads(engine: MegatronEngine):
|
||||
if not (
|
||||
constants.sequence_parallel() and constants.model_parallel_world_size() > 1
|
||||
):
|
||||
return
|
||||
real_model: ReaLModel = engine.ddp.module
|
||||
grads = []
|
||||
for i in range(real_model.layer_idx_start, real_model.layer_idx_end):
|
||||
if i == 0:
|
||||
continue
|
||||
elif i == real_model.config.n_layers + 1:
|
||||
continue
|
||||
else:
|
||||
assert 0 < i < real_model.config.n_layers + 1
|
||||
layer: ReaLModelBlock = real_model.layers[i - real_model.layer_idx_start]
|
||||
grads.append(layer.attn.c_attn.ln.weight.main_grad)
|
||||
if getattr(layer.attn.c_attn.ln, "bias", None) is not None:
|
||||
grads.append(layer.attn.c_attn.ln.bias.main_grad)
|
||||
grads.append(layer.mlp.ln.weight.main_grad)
|
||||
if getattr(layer.mlp.ln, "bias", None) is not None:
|
||||
grads.append(layer.mlp.ln.bias.main_grad)
|
||||
if i == real_model.config.n_layers:
|
||||
grads.append(layer.ln_f.weight.main_grad)
|
||||
if getattr(layer.ln_f, "bias", None) is not None:
|
||||
grads.append(layer.ln_f.bias.main_grad)
|
||||
|
||||
assert all(x is not None for x in grads)
|
||||
coalesced = _flatten_dense_tensors(grads)
|
||||
dist.all_reduce(coalesced, group=constants.model_parallel_group())
|
||||
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
|
||||
buf.copy_(synced)
|
||||
|
||||
|
||||
def megatron_all_reduce_word_embedding_grads(engine: MegatronEngine):
|
||||
real_model: ReaLModel = engine.ddp.module
|
||||
if not real_model.config.tied_embedding or real_model.config.is_critic:
|
||||
return
|
||||
pp_size = constants.pipe_parallel_world_size()
|
||||
pp_rank = constants.pipe_parallel_rank()
|
||||
if pp_size == 1:
|
||||
return
|
||||
if pp_rank not in [0, pp_size - 1]:
|
||||
return
|
||||
|
||||
if pp_rank == 0:
|
||||
grad = real_model.layers[0].wte.weight.main_grad
|
||||
else:
|
||||
grad = real_model.layers[-1].weight.main_grad
|
||||
|
||||
dist.all_reduce(grad, group=constants.grid().embedding_proc_group)
|
||||
|
||||
|
||||
def finalize_grads_megatron(engine: MegatronEngine):
|
||||
engine.ddp.finish_grad_sync()
|
||||
megatron_all_reduce_layernorm_grads(engine)
|
||||
megatron_all_reduce_word_embedding_grads(engine)
|
||||
def finalize_grads(self):
|
||||
self.ddp.finish_grad_sync()
|
||||
self._all_reduce_layernorm_grads()
|
||||
self._all_reduce_word_embedding_grads()
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
@@ -681,7 +311,7 @@ class PipeTrainInstrSetForMegatron(PipeTrainInstrSet):
|
|||
step_id: int,
|
||||
):
|
||||
# self.engine.ddp.start_grad_sync()
|
||||
finalize_grads_megatron(self.engine)
|
||||
self.engine.finalize_grads()
|
||||
|
||||
@cuda_tmarked("opt", CUDATimeMarkType.optim_step)
|
||||
def _exec_optimizer_step(
|
||||
|
@@ -692,16 +322,12 @@ class PipeTrainInstrSetForMegatron(PipeTrainInstrSet):
|
|||
micro_batch_id: int,
|
||||
step_id: int,
|
||||
):
|
||||
if isinstance(self.engine.optim, DistributedOptimizer):
|
||||
update_successful, grad_norm, num_zeros_in_grad = (
|
||||
step_megatron_distrb_optimizer(self.engine.optim)
|
||||
)
|
||||
else:
|
||||
update_successful, grad_norm, num_zeros_in_grad = self.engine.optim.step()
|
||||
update_successful, grad_norm, num_zeros_in_grad = self.engine.optim.step()
|
||||
|
||||
version_steps = tensor_buffer.get("version_steps", 0)
|
||||
if update_successful:
|
||||
self.engine.lr_scheduler.step_absolute(version_steps)
|
||||
incr = version_steps - self.engine.lr_scheduler.num_steps
|
||||
self.engine.lr_scheduler.step(incr)
|
||||
if constants.data_parallel_rank() == 0 and constants.model_parallel_rank() == 0:
|
||||
logger.info(
|
||||
f"Model name {constants.model_name()}, "
|
||||
|
@@ -824,7 +450,7 @@ class ReaLMegatronEngine(model_api.PipelinableEngine):
|
|||
for k, v in _stat.items():
|
||||
stat[k] += v
|
||||
|
||||
finalize_grads_megatron(self.engine)
|
||||
self.engine.finalize_grads()
|
||||
self._step(version_steps)
|
||||
|
||||
return stat
|
||||
|
@@ -834,12 +460,14 @@ class ReaLMegatronEngine(model_api.PipelinableEngine):
|
|||
self,
|
||||
input_: SequenceSample,
|
||||
mb_spec: MicroBatchSpec,
|
||||
output_seqlens: List[List[int]] | None = None,
|
||||
post_hook: Callable[[torch.Tensor, SequenceSample], Any] | None = None,
|
||||
aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
|
||||
):
|
||||
return self.inf_engine.forward(
|
||||
input_=input_,
|
||||
mb_spec=mb_spec,
|
||||
output_seqlens=output_seqlens,
|
||||
post_hook=post_hook,
|
||||
aggregate_fn=aggregate_fn,
|
||||
)
|
||||
|
@@ -864,14 +492,10 @@ class ReaLMegatronEngine(model_api.PipelinableEngine):
|
|||
# wrapper for profiler
|
||||
@cuda_tmarked("opt", CUDATimeMarkType.optim_step)
|
||||
def _step(self, version_steps):
|
||||
if isinstance(self.engine.optim, DistributedOptimizer):
|
||||
update_successful, grad_norm, _ = step_megatron_distrb_optimizer(
|
||||
self.engine.optim
|
||||
)
|
||||
else:
|
||||
update_successful, grad_norm, _ = self.engine.optim.step()
|
||||
update_successful, grad_norm, _ = self.engine.optim.step()
|
||||
if update_successful:
|
||||
self.engine.lr_scheduler.step_absolute(version_steps)
|
||||
incr = version_steps - self.engine.lr_scheduler.num_steps
|
||||
self.engine.lr_scheduler.step(incr)
|
||||
if constants.data_parallel_rank() == 0 and constants.model_parallel_rank() == 0:
|
||||
logger.info(
|
||||
f"Megatron backend update success? {update_successful}. "
|
||||
|
@@ -884,10 +508,6 @@ class ReaLMegatronEngine(model_api.PipelinableEngine):
|
|||
@dataclasses.dataclass
|
||||
class MegatronTrainBackend(model_api.ModelBackend, MegatronConfig):
|
||||
bf16: bool = False
|
||||
zero_stage: int = dataclasses.field(
|
||||
metadata={"choices": [0, 1, 2, 3]},
|
||||
default=2,
|
||||
)
|
||||
optimizer: OptimizerConfig = dataclasses.field(default_factory=OptimizerConfig)
|
||||
|
||||
def _initialize(
|
||||
|
@@ -896,21 +516,32 @@ class MegatronTrainBackend(model_api.ModelBackend, MegatronConfig):
|
|||
module = model.module
|
||||
if not isinstance(module, ReaLModel):
|
||||
raise ValueError("MegatronTrainBackend only supports ReaLModel.")
|
||||
if isinstance(self.ddp, dict):
|
||||
self.ddp = DistributedDataParallelConfig(**self.ddp)
|
||||
with megatron_ctx():
|
||||
module = DistributedDataParallel(
|
||||
config=get_megatron_transformer_config(module.config),
|
||||
module=module,
|
||||
data_parallel_group=constants.data_parallel_group(),
|
||||
accumulate_allreduce_grads_in_fp32=self.accumulate_allreduce_grads_in_fp32,
|
||||
overlap_grad_reduce=self.overlap_grad_reduce,
|
||||
use_distributed_optimizer=self.zero_stage > 0,
|
||||
expert_data_parallel_group=None,
|
||||
disable_bucketing=False,
|
||||
check_for_nan_in_grad=False,
|
||||
)
|
||||
if use_old_megatron:
|
||||
module = DistributedDataParallel(
|
||||
config=get_megatron_transformer_config(module.config),
|
||||
module=module,
|
||||
data_parallel_group=constants.data_parallel_group(),
|
||||
accumulate_allreduce_grads_in_fp32=self.ddp.grad_reduce_in_fp32,
|
||||
overlap_grad_reduce=self.ddp.overlap_grad_reduce,
|
||||
use_distributed_optimizer=self.ddp.use_distributed_optimizer,
|
||||
expert_data_parallel_group=None,
|
||||
disable_bucketing=False,
|
||||
check_for_nan_in_grad=self.ddp.check_for_nan_in_grad,
|
||||
bucket_size=self.ddp.bucket_size,
|
||||
)
|
||||
else:
|
||||
module = DistributedDataParallel(
|
||||
config=get_megatron_transformer_config(module.config),
|
||||
ddp_config=self.ddp,
|
||||
module=module,
|
||||
disable_bucketing=False,
|
||||
)
|
||||
|
||||
real_model: ReaLModel = module.module
|
||||
if self.zero_stage > 0:
|
||||
if self.ddp.use_distributed_optimizer:
|
||||
# Remap parameters.
|
||||
assert len(module.buffers) == 1
|
||||
param_grad_buf: ParamAndGradBuffer = module.buffers[0]
|
||||
|
@@ -946,22 +577,23 @@ class MegatronTrainBackend(model_api.ModelBackend, MegatronConfig):
|
|||
opt_cfg = MegatronOptimizerConfig(
|
||||
optimizer=self.optimizer.type,
|
||||
bf16=self.bf16,
|
||||
fp16=not self.bf16,
|
||||
lr=lr,
|
||||
min_lr=self.optimizer.min_lr_ratio * lr,
|
||||
weight_decay=wd,
|
||||
params_dtype=real_model.dtype,
|
||||
initial_loss_scale=self.optimizer.initial_loss_scale,
|
||||
adam_beta1=betas[0],
|
||||
adam_beta2=betas[1],
|
||||
adam_eps=self.optimizer.eps,
|
||||
sgd_momentum=0.9,
|
||||
use_distributed_optimizer=self.zero_stage > 0,
|
||||
overlap_grad_reduce=self.overlap_grad_reduce,
|
||||
overlap_param_gather=self.overlap_param_gather,
|
||||
clip_grad=self.optimizer.gradient_clipping,
|
||||
min_loss_scale=self.optimizer.min_loss_scale,
|
||||
loss_scale_window=self.optimizer.loss_scale_window,
|
||||
hysteresis=self.optimizer.hysteresis,
|
||||
adam_beta1=betas[0],
|
||||
adam_beta2=betas[1],
|
||||
adam_eps=self.optimizer.eps,
|
||||
use_distributed_optimizer=self.ddp.use_distributed_optimizer,
|
||||
overlap_grad_reduce=self.ddp.overlap_grad_reduce,
|
||||
overlap_param_gather=self.ddp.overlap_param_gather,
|
||||
clip_grad=self.optimizer.gradient_clipping,
|
||||
log_num_zeros_in_grad=False,
|
||||
)
|
||||
|
||||
with megatron_ctx():
|
||||
|
@@ -1005,7 +637,7 @@ class MegatronTrainBackend(model_api.ModelBackend, MegatronConfig):
|
|||
# create circular references (grad -> param -> grad).
|
||||
# Deleting models directly will not release the memory.
|
||||
# We must disable hooks at first.
|
||||
if self.zero_stage > 0 and self.overlap_param_gather:
|
||||
if self.ddp.use_distributed_optimizer and self.ddp.overlap_param_gather:
|
||||
optimizer.disable_pre_hook()
|
||||
|
||||
def save(self, model: model_api.Model, save_dir: str):
|
||||
|
|
|
@@ -179,12 +179,14 @@ class MockTrainEngine(model_api.PipelinableEngine):
|
|||
self,
|
||||
input_: SequenceSample,
|
||||
mb_spec: MicroBatchSpec,
|
||||
output_seqlens: List[List[int]] | None = None,
|
||||
post_hook: Callable[[torch.Tensor, SequenceSample], Any] | None = None,
|
||||
aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
|
||||
):
|
||||
return self.inf_engine.forward(
|
||||
input_=input_,
|
||||
mb_spec=mb_spec,
|
||||
output_seqlens=output_seqlens,
|
||||
post_hook=post_hook,
|
||||
aggregate_fn=aggregate_fn,
|
||||
)
|
||||
|
|
|
@@ -0,0 +1,453 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from importlib.metadata import version
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
import transformers
|
||||
from packaging.version import Version
|
||||
from tqdm.asyncio import tqdm
|
||||
|
||||
from realhf.api.core import data_api
|
||||
from realhf.api.core.model_api import (
|
||||
APIGenerateInput,
|
||||
APIGenerateOutput,
|
||||
FinetuneSpec,
|
||||
GenerationHyperparameters,
|
||||
LLMAPIClient,
|
||||
Model,
|
||||
ModelBackend,
|
||||
PipelinableEngine,
|
||||
register_backend,
|
||||
)
|
||||
from realhf.api.quickstart.model import SGLangConfig
|
||||
from realhf.base import cluster, constants, gpu_utils, logging, network, seeding
|
||||
|
||||
logger = logging.getLogger("SGLang backend")
|
||||
|
||||
|
||||
def remove_prefix(text: str, prefix: str) -> str:
|
||||
return text[len(prefix) :] if text.startswith(prefix) else text
|
||||
|
||||
|
||||
class SGLangAPIClient(LLMAPIClient):
|
||||
|
||||
async def _do_generate(
|
||||
self, req: APIGenerateInput, stream: bool = False
|
||||
) -> APIGenerateOutput:
|
||||
gconfig = req.gconfig
|
||||
sample_params = {
|
||||
"n": gconfig.n,
|
||||
"top_p": gconfig.top_p,
|
||||
"top_k": gconfig.top_k,
|
||||
"max_new_tokens": gconfig.max_new_tokens,
|
||||
"temperature": 0.0 if gconfig.greedy else gconfig.temperature,
|
||||
"stop_token_ids": req.stop_token_ids,
|
||||
}
|
||||
payload = {
|
||||
"input_ids": req.prompt_ids,
|
||||
"sampling_params": sample_params,
|
||||
"return_logprob": req.return_logprob,
|
||||
"stream": stream,
|
||||
}
|
||||
|
||||
output = APIGenerateOutput.from_input(req)
|
||||
|
||||
# The following code is partially adopted from sglang/bench_serving.py
|
||||
output_ids = []
|
||||
output_logprobs = []
|
||||
finish_reason = {}
|
||||
ttft = 0.0
|
||||
latency = float("inf")
|
||||
st = time.perf_counter()
|
||||
most_recent_timestamp = st
|
||||
timeout = aiohttp.ClientTimeout(total=None, connect=30, sock_read=None)
|
||||
try:
|
||||
async with self.session.post(
|
||||
url=self.generate_url,
|
||||
json=payload,
|
||||
timeout=timeout,
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
async for chunk_bytes in response.content:
|
||||
chunk_bytes = chunk_bytes.strip()
|
||||
if not chunk_bytes:
|
||||
continue
|
||||
|
||||
chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
|
||||
latency = time.perf_counter() - st
|
||||
if chunk == "[DONE]":
|
||||
pass
|
||||
else:
|
||||
data = json.loads(chunk)
|
||||
|
||||
# NOTE: Some completion API might have a last
|
||||
# usage summary response without a token so we
|
||||
# want to check a token was generated
|
||||
if data["token_ids"]:
|
||||
timestamp = time.perf_counter()
|
||||
# First token
|
||||
if ttft == 0.0:
|
||||
ttft = time.perf_counter() - st
|
||||
output.ttft = ttft
|
||||
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(timestamp - most_recent_timestamp)
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
output_ids = data["token_ids"]
|
||||
finish_reason = data["meta_info"]["finish_reason"]
|
||||
output_logprobs = data["meta_info"][
|
||||
"output_token_logprobs"
|
||||
]
|
||||
|
||||
assert finish_reason["type"] in ["length", "stop"], finish_reason
|
||||
output.output_logprobs = [x[0] for x in output_logprobs]
|
||||
output.output_ids = output_ids
|
||||
output.no_eos = finish_reason["type"] == "length"
|
||||
output.success = True
|
||||
output.latency = latency
|
||||
else:
|
||||
output.error = response.reason or ""
|
||||
output.success = False
|
||||
except Exception as e:
|
||||
output.success = False
|
||||
exc_info = sys.exc_info()
|
||||
output.error = "".join(traceback.format_exception(*exc_info))
|
||||
raise RuntimeError(
|
||||
f"SGLang generation request fails:\n{output.error}"
|
||||
) from e
|
||||
|
||||
return output
|
||||
|
||||
async def async_update_weights_from_disk(self, path):
|
||||
timeout = aiohttp.ClientTimeout(total=300, connect=30, sock_read=None)
|
||||
async with self.session.post(
|
||||
url=self.update_weights_url,
|
||||
json=dict(model_path=path),
|
||||
timeout=timeout,
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
raise RuntimeError("Update weights failed.")
|
||||
|
||||
|
||||
def sglang_server_process(server_args_dict):
|
||||
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
|
||||
sglang_version = version("sglang")
|
||||
if Version(sglang_version) < Version("0.4.3"):
|
||||
from sglang.srt.server import launch_server
|
||||
|
||||
server_args_dict.pop("enable_nccl_nvls")
|
||||
server_args_dict.pop("triton_attention_num_kv_splits")
|
||||
server_args_dict.pop("cuda_graph_bs")
|
||||
server_args_dict.pop("enable_memory_saver")
|
||||
server_args_dict.pop("allow_auto_truncate")
|
||||
server_args_dict.pop("return_hidden_states")
|
||||
else:
|
||||
from sglang.srt.entrypoints.http_server import launch_server
|
||||
|
||||
server_args = ServerArgs(**server_args_dict)
|
||||
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
|
||||
map(str, list(range(gpu_utils.gpu_count())))
|
||||
)
|
||||
|
||||
try:
|
||||
launch_server(server_args)
|
||||
finally:
|
||||
kill_process_tree(os.getpid(), include_parent=False)
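# --- Illustrative sketch, not part of the patch above. ---
# The version-gating pattern used in sglang_server_process: drop config keys that an
# older dependency does not understand before forwarding the arguments. The key names
# mirror two of the options popped above; the helper name itself is made up.
from packaging.version import Version

def filter_server_args(args: dict, lib_version: str) -> dict:
    newer_only_keys = {"enable_memory_saver", "return_hidden_states"}
    if Version(lib_version) < Version("0.4.3"):
        return {k: v for k, v in args.items() if k not in newer_only_keys}
    return dict(args)

print(filter_server_args({"port": 30000, "return_hidden_states": True}, "0.4.2"))
# {'port': 30000}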
|
||||
|
||||
|
||||
class SGLangGenerationEngine(PipelinableEngine):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_args_dict: Dict,
|
||||
hybrid_train: bool,
|
||||
request_timeout: int = 1800,
|
||||
):
|
||||
if constants.model_parallel_rank() != 0:
|
||||
dist.barrier(group=constants.model_parallel_cpu_group())
|
||||
return
|
||||
# Start the serving process
|
||||
self.server_proc = mp.Process(
|
||||
target=sglang_server_process,
|
||||
args=(server_args_dict,),
|
||||
)
|
||||
self.server_proc.start()
|
||||
|
||||
self.base_url = f"http://{server_args_dict['host']}:{server_args_dict['port']}"
|
||||
|
||||
self.api_urls = {
|
||||
"generate": f"{self.base_url}/generate",
|
||||
"offload_weights": f"{self.base_url}/offload_weights",
|
||||
"init_kv_cache": f"{self.base_url}/init_kv_cache",
|
||||
"clear_kv_cache": f"{self.base_url}/clear_kv_cache",
|
||||
"init_model_weights": f"{self.base_url}/init_model_weights",
|
||||
"update_weights_from_disk": f"{self.base_url}/update_weights_from_disk",
|
||||
}
|
||||
|
||||
asyncio.run(self.wait_server())
|
||||
|
||||
self.request_timeout = request_timeout
|
||||
|
||||
# offload weights/cache
|
||||
self.hybrid_train = hybrid_train
|
||||
|
||||
dist.barrier(group=constants.model_parallel_cpu_group())
|
||||
|
||||
def __del__(self):
|
||||
if hasattr(self, "server_proc"):
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
|
||||
self.server_proc.terminate()
|
||||
|
||||
kill_process_tree(os.getpid())
|
||||
|
||||
# NOTE: A placeholder function.
|
||||
def train(self, mode: bool = True):
|
||||
return self
|
||||
|
||||
# NOTE: A placeholder function.
|
||||
def eval(self):
|
||||
return self
|
||||
|
||||
async def wait_server(self):
|
||||
# Wait until the server is launched
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.utils import get_exception_traceback
|
||||
|
||||
success = False
|
||||
for _ in range(120):
|
||||
await asyncio.sleep(1)
|
||||
try:
|
||||
res = requests.get(
|
||||
self.base_url + "/get_model_info", timeout=5, headers={}
|
||||
)
|
||||
assert res.status_code == 200, f"{res=}, {res.text=}"
|
||||
success = True
|
||||
break
|
||||
except (AssertionError, requests.exceptions.RequestException):
|
||||
last_traceback = get_exception_traceback()
|
||||
pass
|
||||
if not success:
|
||||
logger.error(f"Initialization failed. warmup error: {last_traceback}")
|
||||
kill_process_tree(os.getpid())
|
||||
return
|
||||
|
||||
async def async_generate(
|
||||
self,
|
||||
input_: data_api.SequenceSample,
|
||||
mb_spec: data_api.MicroBatchSpec,
|
||||
tokenizer: transformers.PreTrainedTokenizerFast,
|
||||
gconfig: GenerationHyperparameters = dataclasses.field(
|
||||
default_factory=GenerationHyperparameters
|
||||
),
|
||||
stream: bool = False,
|
||||
disable_tqdm: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor | None] | None:
|
||||
|
||||
pbar = None if disable_tqdm else tqdm(total=input_.bs * gconfig.n)
|
||||
|
||||
async with SGLangAPIClient(
|
||||
generate_url=self.api_urls["generate"],
|
||||
update_weights_url=self.api_urls["update_weights_from_disk"],
|
||||
) as client:
|
||||
tasks = []
|
||||
input_queries = []
|
||||
for d in input_.unpack():
|
||||
if len(d.seqlens["packed_input_ids"]) > 1:
|
||||
raise RuntimeError(
|
||||
f"sglang backend does not support grouped generation "
|
||||
f"for now. Group size {len(d.seqlens['packed_input_ids'])}."
|
||||
)
|
||||
|
||||
prompt_token_ids = d.data["packed_input_ids"].cpu().numpy().tolist()
|
||||
qid = d.ids[0]
|
||||
for group_idx in range(gconfig.n):
|
||||
req = APIGenerateInput(
|
||||
qid=qid,
|
||||
group_idx=group_idx,
|
||||
prompt_ids=prompt_token_ids,
|
||||
input_ids=prompt_token_ids,
|
||||
gconfig=gconfig.new(n=1),
|
||||
stop_token_ids=[tokenizer.pad_token_id, tokenizer.eos_token_id],
|
||||
return_logprob=True,
|
||||
)
|
||||
input_queries.append((qid, group_idx))
|
||||
tasks.append(
|
||||
client.async_add_generate_request(
|
||||
req,
|
||||
stream=stream,
|
||||
)
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
for r in asyncio.as_completed(tasks):
|
||||
out = await r
|
||||
outputs[(out.qid, out.group_idx)] = out
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
|
||||
if pbar is not None:
|
||||
pbar.close()
|
||||
|
||||
results: List[APIGenerateOutput] = [outputs[key] for key in input_queries]
|
||||
|
||||
# Build the output: generated token ids, generated token scores,
|
||||
# and logits mask (which will always be None in sglang).
|
||||
batch_token_ids = []
|
||||
batch_logprobs = []
|
||||
max_seqlen = -1
|
||||
for x in results:
|
||||
max_seqlen = max(max_seqlen, len(x.output_ids))
|
||||
batch_token_ids.append(x.output_ids)
|
||||
batch_logprobs.append(x.output_logprobs)
|
||||
|
||||
# To be consistent with our internal implementation,
|
||||
# we should pad generated tokens and logprobs
|
||||
batch_token_ids = [
|
||||
t + [tokenizer.pad_token_id] * (max_seqlen - len(t))
|
||||
for t in batch_token_ids
|
||||
]
|
||||
batch_logprobs = [p + [0.0] * (max_seqlen - len(p)) for p in batch_logprobs]
|
||||
|
||||
return (
|
||||
torch.tensor(
|
||||
batch_token_ids, dtype=torch.long, device=constants.current_device()
|
||||
),
|
||||
torch.tensor(
|
||||
batch_logprobs, dtype=torch.float32, device=constants.current_device()
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
input_: data_api.SequenceSample,
|
||||
mb_spec: data_api.MicroBatchSpec,
|
||||
tokenizer: transformers.PreTrainedTokenizerFast,
|
||||
gconfig: GenerationHyperparameters = dataclasses.field(
|
||||
default_factory=GenerationHyperparameters
|
||||
),
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor | None] | None:
|
||||
if gconfig.min_new_tokens != 0:
|
||||
raise RuntimeError(
|
||||
"NOTE: passing in an arbitrary `min_new_tokens` will lead to a bug for SGLang v0.4.3 "
|
||||
"because we force to skip_tokenizer_init."
|
||||
)
|
||||
if constants.model_parallel_rank() != 0:
|
||||
dist.barrier(group=constants.model_parallel_cpu_group())
|
||||
return None, None, None
|
||||
|
||||
results = asyncio.run(
|
||||
self.async_generate(
|
||||
input_=input_,
|
||||
mb_spec=mb_spec,
|
||||
tokenizer=tokenizer,
|
||||
gconfig=gconfig,
|
||||
)
|
||||
)
|
||||
dist.barrier(group=constants.model_parallel_cpu_group())
|
||||
return results
|
||||
|
||||
def update_weights_from_disk(self, path):
|
||||
if constants.model_parallel_rank() != 0:
|
||||
dist.barrier(group=constants.model_parallel_cpu_group())
|
||||
return
|
||||
|
||||
async def _fn():
|
||||
async with SGLangAPIClient(
|
||||
generate_url=self.api_urls["generate"],
|
||||
update_weights_url=self.api_urls["update_weights_from_disk"],
|
||||
) as client:
|
||||
await client.async_update_weights_from_disk(path)
|
||||
|
||||
asyncio.run(_fn())
|
||||
dist.barrier(group=constants.model_parallel_cpu_group())
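# --- Illustrative sketch, not part of the patch above. ---
# The right-padding done at the end of async_generate, in isolation: token ids are padded
# with the tokenizer's pad id and logprobs with 0.0 up to the longest sequence in the
# batch. The values below are invented for the example.
import torch

seqs = [[5, 6, 7], [8]]
logps = [[-0.1, -0.2, -0.3], [-0.5]]
pad_id = 0
maxlen = max(len(s) for s in seqs)
tokens = torch.tensor([s + [pad_id] * (maxlen - len(s)) for s in seqs])
logprobs = torch.tensor([p + [0.0] * (maxlen - len(p)) for p in logps])
print(tokens.shape, logprobs.shape)  # torch.Size([2, 3]) torch.Size([2, 3])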
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SGLangGenerationBackend(ModelBackend, SGLangConfig):
|
||||
model_path: str = ""
|
||||
dtype: str = "float16"
|
||||
|
||||
def _initialize(self, model: Model, spec: FinetuneSpec) -> Model:
|
||||
if constants.pipe_parallel_world_size() != 1:
|
||||
raise RuntimeError("SGLang does not support pipe parallel size > 1.")
|
||||
if constants.model_parallel_world_size() > cluster.spec.n_gpus_per_node:
|
||||
raise RuntimeError(
|
||||
"AReaL's SGLang integration does not support model parallel size > n_gpus_per_node."
|
||||
)
|
||||
|
||||
additional_args = dataclasses.asdict(self)
|
||||
additional_args.pop("hybrid_train")
|
||||
additional_args["random_seed"] = seeding.get_seed()
|
||||
|
||||
# For simplicity, we let all DP ranks have different ports.
|
||||
ports = [None for _ in range(constants.data_parallel_world_size())]
|
||||
while any(port is None for port in ports) or len(set(ports)) != len(ports):
|
||||
dist.all_gather_object(
|
||||
ports, network.find_free_port(), group=constants.data_parallel_group()
|
||||
)
|
||||
additional_args["port"] = ports[constants.data_parallel_rank()]
|
||||
|
||||
server_args_dict = dict(
|
||||
host="localhost",
|
||||
# Model and tokenizer
|
||||
tokenizer_path=self.model_path,
|
||||
tokenizer_mode="auto",
|
||||
load_format="auto",
|
||||
trust_remote_code=True,
|
||||
kv_cache_dtype="auto",
|
||||
device="cuda",
|
||||
served_model_name=f"{constants.experiment_name()}/{constants.trial_name()}/{constants.model_name().role}",
|
||||
is_embedding=False,
|
||||
skip_tokenizer_init=True,
|
||||
# Other runtime options
|
||||
tp_size=constants.model_parallel_world_size(),
|
||||
# Because we have set CUDA_VISIBLE_DEVICES to a single GPU in each process
|
||||
base_gpu_id=int(os.environ["CUDA_VISIBLE_DEVICES"]),
|
||||
file_storage_pth=os.path.join(
|
||||
constants.SGLANG_CACHE_PATH,
|
||||
f"sglang_storage{constants.data_parallel_rank()}",
|
||||
),
|
||||
# Data parallelism
|
||||
dp_size=1, # TODO: check whether we require SGLang dp
|
||||
load_balance_method="round_robin",
|
||||
# Expert parallelism
|
||||
ep_size=1, # TODO: check
|
||||
# Multi-node distributed serving
|
||||
dist_init_addr=None,
|
||||
nnodes=1,
|
||||
node_rank=0,
|
||||
**additional_args,
|
||||
)
|
||||
|
||||
model.module = SGLangGenerationEngine(
|
||||
server_args_dict,
|
||||
hybrid_train=self.hybrid_train,
|
||||
)
|
||||
model.backend_name = "sglang"
|
||||
return model
|
||||
|
||||
def load(self, model: Model, load_dir: str):
|
||||
model.module.update_weights_from_disk(load_dir)
|
||||
|
||||
|
||||
register_backend("sglang", SGLangGenerationBackend)
|
|
@@ -0,0 +1,249 @@
|
|||
import math
|
||||
|
||||
from realhf.base import constants, logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Adopted from Megatron-LM/megatron/training/optimizer_param_scheduler.py
|
||||
class OptimizerParamScheduler(object):
|
||||
"""Anneals learning rate and weight decay.
|
||||
|
||||
Adopted from Megatron-LM. This class is not included in
|
||||
megatron.core, so we have to copy-paste it here.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
optimizer,
|
||||
init_lr,
|
||||
max_lr,
|
||||
min_lr,
|
||||
lr_warmup_steps,
|
||||
lr_decay_steps,
|
||||
lr_decay_style,
|
||||
start_wd,
|
||||
end_wd,
|
||||
wd_incr_steps,
|
||||
wd_incr_style,
|
||||
use_checkpoint_opt_param_scheduler=True,
|
||||
override_opt_param_scheduler=False,
|
||||
):
|
||||
|
||||
# Class values.
|
||||
self.optimizer = optimizer
|
||||
|
||||
self.init_lr = init_lr
|
||||
self.max_lr = float(max_lr)
|
||||
self.min_lr = min_lr
|
||||
assert self.min_lr >= 0.0
|
||||
assert self.max_lr >= self.min_lr
|
||||
assert self.init_lr <= self.max_lr
|
||||
|
||||
self.lr_warmup_steps = lr_warmup_steps
|
||||
self.num_steps = 0
|
||||
self.lr_decay_steps = lr_decay_steps
|
||||
assert self.lr_decay_steps > 0
|
||||
assert self.lr_warmup_steps < self.lr_decay_steps
|
||||
|
||||
self.lr_decay_style = lr_decay_style
|
||||
|
||||
self.start_wd = start_wd
|
||||
self.end_wd = end_wd
|
||||
assert self.start_wd >= 0.0
|
||||
assert self.end_wd >= self.start_wd
|
||||
self.wd_incr_steps = wd_incr_steps
|
||||
self.wd_incr_style = wd_incr_style
|
||||
|
||||
self.override_opt_param_scheduler = override_opt_param_scheduler
|
||||
self.use_checkpoint_opt_param_scheduler = use_checkpoint_opt_param_scheduler
|
||||
if self.override_opt_param_scheduler:
|
||||
assert not self.use_checkpoint_opt_param_scheduler, (
|
||||
"both override and " "use-checkpoint are set."
|
||||
)
|
||||
|
||||
# Set the learning rate
|
||||
self.step(0)
|
||||
self.log_rank_0("> learning rate decay style: {}".format(self.lr_decay_style))
|
||||
|
||||
def log_rank_0(self, msg):
|
||||
if constants.parallelism_rank() == 0:
|
||||
logger.info(msg)
|
||||
|
||||
def get_wd(self):
|
||||
"""Weight decay incr functions."""
|
||||
if self.num_steps > self.wd_incr_steps:
|
||||
return self.end_wd
|
||||
|
||||
if self.wd_incr_style == "constant":
|
||||
assert self.start_wd == self.end_wd
|
||||
return self.end_wd
|
||||
|
||||
incr_ratio = float(self.num_steps) / float(self.wd_incr_steps)
|
||||
assert incr_ratio >= 0.0
|
||||
assert incr_ratio <= 1.0
|
||||
delta_wd = self.end_wd - self.start_wd
|
||||
|
||||
if self.wd_incr_style == "linear":
|
||||
coeff = incr_ratio
|
||||
elif self.wd_incr_style == "cosine":
|
||||
coeff = 0.5 * (math.cos(math.pi * (1 - incr_ratio)) + 1.0)
|
||||
else:
|
||||
raise Exception(
|
||||
"{} weight decay increment style is not supported.".format(
|
||||
self.wd_incr_style
|
||||
)
|
||||
)
|
||||
|
||||
return self.start_wd + coeff * delta_wd
|
||||
|
||||
def get_lr(self, param_group):
|
||||
"""Learning rate decay functions from:
|
||||
https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""
|
||||
|
||||
max_lr = param_group.get("max_lr", self.max_lr)
|
||||
min_lr = param_group.get("min_lr", self.min_lr)
|
||||
|
||||
# Use linear warmup for the initial part.
|
||||
if self.lr_warmup_steps > 0 and self.num_steps <= self.lr_warmup_steps:
|
||||
return self.init_lr + (
|
||||
(max_lr - self.init_lr)
|
||||
* float(self.num_steps)
|
||||
/ float(self.lr_warmup_steps)
|
||||
)
|
||||
|
||||
# If the learning rate is constant, just return the initial value.
|
||||
if self.lr_decay_style == "constant":
|
||||
return max_lr
|
||||
|
||||
# For any steps larger than `self.lr_decay_steps`, use `min_lr`.
|
||||
if self.num_steps > self.lr_decay_steps:
|
||||
return min_lr
|
||||
|
||||
# If we are done with the warmup period, use the decay style.
|
||||
if self.lr_decay_style == "inverse-square-root":
|
||||
warmup_steps = max(self.lr_warmup_steps, 1)
|
||||
num_steps = max(self.num_steps, 1)
|
||||
lr = max_lr * warmup_steps**0.5 / (num_steps**0.5)
|
||||
return max(min_lr, lr)
|
||||
|
||||
num_steps_ = self.num_steps - self.lr_warmup_steps
|
||||
decay_steps_ = self.lr_decay_steps - self.lr_warmup_steps
|
||||
decay_ratio = float(num_steps_) / float(decay_steps_)
|
||||
assert decay_ratio >= 0.0
|
||||
assert decay_ratio <= 1.0
|
||||
delta_lr = max_lr - min_lr
|
||||
|
||||
if self.lr_decay_style == "linear":
|
||||
coeff = 1.0 - decay_ratio
|
||||
elif self.lr_decay_style == "cosine":
|
||||
coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
|
||||
else:
|
||||
raise Exception(
|
||||
"{} decay style is not supported.".format(self.lr_decay_style)
|
||||
)
|
||||
|
||||
return min_lr + coeff * delta_lr
|
||||
|
||||
def step(self, increment):
|
||||
"""Set lr for all parameters groups."""
|
||||
self.num_steps += increment
|
||||
new_wd = self.get_wd()
|
||||
for param_group in self.optimizer.param_groups:
|
||||
new_lr = self.get_lr(param_group)
|
||||
param_group["lr"] = new_lr * param_group.get("lr_mult", 1.0)
|
||||
param_group["weight_decay"] = new_wd * param_group.get("wd_mult", 1.0)
|
||||
|
||||
def state_dict(self):
|
||||
state_dict = {
|
||||
"max_lr": self.max_lr,
|
||||
"lr_warmup_steps": self.lr_warmup_steps,
|
||||
"num_steps": self.num_steps,
|
||||
"lr_decay_style": self.lr_decay_style,
|
||||
"lr_decay_steps": self.lr_decay_steps,
|
||||
"min_lr": self.min_lr,
|
||||
"start_wd": self.start_wd,
|
||||
"end_wd": self.end_wd,
|
||||
"wd_incr_style": self.wd_incr_style,
|
||||
"wd_incr_steps": self.wd_incr_steps,
|
||||
}
|
||||
return state_dict
|
||||
|
||||
def _check_and_set(self, cls_value, sd_value, name):
|
||||
"""Auxiliary function for checking the values in the checkpoint and
|
||||
setting them."""
|
||||
if self.override_opt_param_scheduler:
|
||||
self.log_rank_0(" > overriding {} value to {}".format(name, cls_value))
|
||||
return cls_value
|
||||
|
||||
if not self.use_checkpoint_opt_param_scheduler:
|
||||
assert cls_value == sd_value, (
|
||||
f"OptimizerParamScheduler: class input value {cls_value} and checkpoint"
|
||||
f"value {sd_value} for {name} do not match"
|
||||
)
|
||||
self.log_rank_0(" > using checkpoint value {} for {}".format(sd_value, name))
|
||||
return sd_value
|
||||
|
||||
def load_state_dict(self, sd):
|
||||
|
||||
if "start_lr" in sd:
|
||||
max_lr_ = sd["start_lr"]
|
||||
else:
|
||||
max_lr_ = sd["max_lr"]
|
||||
self.max_lr = self._check_and_set(self.max_lr, max_lr_, "learning rate")
|
||||
|
||||
self.min_lr = self._check_and_set(
|
||||
self.min_lr, sd["min_lr"], "minimum learning rate"
|
||||
)
|
||||
|
||||
if "warmup_iter" in sd:
|
||||
lr_warmup_steps_ = sd["warmup_iter"]
|
||||
elif "warmup_steps" in sd:
|
||||
lr_warmup_steps_ = sd["warmup_steps"]
|
||||
else:
|
||||
lr_warmup_steps_ = sd["lr_warmup_steps"]
|
||||
self.lr_warmup_steps = self._check_and_set(
|
||||
self.lr_warmup_steps, lr_warmup_steps_, "warmup iterations"
|
||||
)
|
||||
|
||||
if "end_iter" in sd:
|
||||
lr_decay_steps_ = sd["end_iter"]
|
||||
elif "decay_steps" in sd:
|
||||
lr_decay_steps_ = sd["decay_steps"]
|
||||
else:
|
||||
lr_decay_steps_ = sd["lr_decay_steps"]
|
||||
self.lr_decay_steps = self._check_and_set(
|
||||
self.lr_decay_steps, lr_decay_steps_, "total number of iterations"
|
||||
)
|
||||
|
||||
if "decay_style" in sd:
|
||||
lr_decay_style_ = sd["decay_style"]
|
||||
else:
|
||||
lr_decay_style_ = sd["lr_decay_style"]
|
||||
self.lr_decay_style = self._check_and_set(
|
||||
self.lr_decay_style, lr_decay_style_, "learning rate decay style"
|
||||
)
|
||||
|
||||
if "num_iters" in sd:
|
||||
num_steps = sd["num_iters"]
|
||||
else:
|
||||
num_steps = sd["num_steps"]
|
||||
self.step(increment=num_steps)
|
||||
|
||||
if "start_wd" in sd:
|
||||
self.start_wd = self._check_and_set(
|
||||
self.start_wd, sd["start_wd"], "start weight decay"
|
||||
)
|
||||
self.end_wd = self._check_and_set(
|
||||
self.end_wd, sd["end_wd"], "end weight decay"
|
||||
)
|
||||
self.wd_incr_steps = self._check_and_set(
|
||||
self.wd_incr_steps,
|
||||
sd["wd_incr_steps"],
|
||||
"total number of weight decay iterations",
|
||||
)
|
||||
self.wd_incr_style = self._check_and_set(
|
||||
self.wd_incr_style,
|
||||
sd["wd_incr_style"],
|
||||
"weight decay incr style",
|
||||
)
|
|
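# --- Illustrative sketch, not part of the patch above. ---
# The warmup-plus-cosine schedule implemented by get_lr(), reproduced as a standalone
# function so the shape of the curve can be checked in isolation. All numbers are
# invented; the real class also handles per-group max_lr/min_lr and the "constant"
# and "inverse-square-root" styles.
import math

def cosine_lr(step, init_lr, max_lr, min_lr, warmup_steps, decay_steps):
    if warmup_steps > 0 and step <= warmup_steps:
        return init_lr + (max_lr - init_lr) * step / warmup_steps
    if step > decay_steps:
        return min_lr
    decay_ratio = (step - warmup_steps) / (decay_steps - warmup_steps)
    return min_lr + (max_lr - min_lr) * 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)

print(cosine_lr(0, 0.0, 1e-4, 1e-5, 10, 100))    # 0.0 (start of warmup)
print(cosine_lr(10, 0.0, 1e-4, 1e-5, 10, 100))   # 0.0001 (peak after warmup)
print(cosine_lr(100, 0.0, 1e-4, 1e-5, 10, 100))  # 1e-05 (fully decayed)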
@@ -136,6 +136,8 @@ def setup_global_comm(
|
|||
for model_name, ranks in mw_ranks.items():
|
||||
model_groups[model_name] = topology.new_or_get_group(ranks, backend=backend)
|
||||
constants.set_parallelism_group(model_name, model_groups[model_name], ranks)
|
||||
cpu_group = topology.new_or_get_group(ranks, backend="gloo")
|
||||
constants.set_cpu_parallelism_group(model_name, cpu_group)
|
||||
|
||||
self_group = None
|
||||
for i in range(world_size):
|
||||
|
|
|
@@ -62,6 +62,9 @@ class ParamReallocInfo:
|
|||
param_realloc_model_group: Dict[
|
||||
ParamReallocModelPair, torch.distributed.ProcessGroup
|
||||
]
|
||||
param_realloc_model_cpu_group: Dict[
|
||||
ParamReallocModelPair, torch.distributed.ProcessGroup
|
||||
]
|
||||
param_realloc_groups: Dict[ParamReallocPair, torch.distributed.ProcessGroup]
|
||||
param_realloc_src_ranks: Dict[ParamReallocPair, int]
|
||||
param_realloc_dst_ranks: Dict[ParamReallocPair, List[int]]
|
||||
|
@@ -270,6 +273,7 @@ def setup_param_realloc(
|
|||
param_realloc_src_ranks = {}
|
||||
param_realloc_dst_ranks = {}
|
||||
param_realloc_model_group = {}
|
||||
param_realloc_model_cpu_group = {}
|
||||
if param_realloc_pairs is not None:
|
||||
for src, dst in param_realloc_pairs:
|
||||
_create_param_realloc_groups(
|
||||
|
@@ -296,11 +300,15 @@ def setup_param_realloc(
|
|||
param_realloc_model_group[ParamReallocModelPair(src, dst)] = (
|
||||
topology.new_or_get_group(list(sorted(pair_mw_ranks)))
|
||||
)
|
||||
param_realloc_model_cpu_group[ParamReallocModelPair(src, dst)] = (
|
||||
topology.new_or_get_group(list(sorted(pair_mw_ranks)), backend="gloo")
|
||||
)
|
||||
return ParamReallocInfo(
|
||||
param_realloc_groups=param_realloc_groups,
|
||||
param_realloc_src_ranks=param_realloc_src_ranks,
|
||||
param_realloc_dst_ranks=param_realloc_dst_ranks,
|
||||
param_realloc_model_group=param_realloc_model_group,
|
||||
param_realloc_model_cpu_group=param_realloc_model_cpu_group,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@@ -444,13 +444,13 @@ class PPOActorInterface(model_api.ModelInterface):
|
|||
)
|
||||
|
||||
res = SequenceSample(
|
||||
keys=["packed_ref_logprobs"],
|
||||
keys=["logprobs"],
|
||||
ids=input_.ids,
|
||||
dtypes=dict(packed_ref_logprobs=model.module.dtype),
|
||||
trailing_shapes=dict(packed_ref_logprobs=()),
|
||||
data=dict(packed_ref_logprobs=logprobs),
|
||||
dtypes=dict(logprobs=model.module.dtype),
|
||||
trailing_shapes=dict(logprobs=()),
|
||||
data=dict(logprobs=logprobs),
|
||||
seqlens=dict(
|
||||
packed_ref_logprobs=[
|
||||
logprobs=[
|
||||
[x - 1 for x in slen] for slen in input_.seqlens["packed_input_ids"]
|
||||
]
|
||||
),
|
||||
|
|
|
@@ -11,21 +11,17 @@ from typing import Dict, Literal, Optional, Tuple
|
|||
|
||||
import torch
|
||||
|
||||
|
||||
import realhf.api.core.model_api as model_api
|
||||
|
||||
import realhf.base.logging as logging
|
||||
import realhf.impl.model.utils.ppo_functional as ppo_functional
|
||||
from realhf.api.core.data_api import MicroBatchSpec, SequenceSample
|
||||
from realhf.base.datapack import flat2d
|
||||
|
||||
from realhf.impl.model.interface.rw_interface import PackedRewardInterface
|
||||
from realhf.impl.model.nn.real_llm_api import ReaLModel
|
||||
|
||||
from realhf.impl.model.utils.functional import (
|
||||
gather_packed_shifted_log_probs,
|
||||
masked_normalization,
|
||||
)
|
||||
from realhf.impl.model.interface.rw_interface import PackedRewardInterface
|
||||
|
||||
logger = logging.getLogger("RefRwInterface")
|
||||
|
||||
|
|
|
@@ -122,6 +122,7 @@ class ReaLModel(nn.Module):
|
|||
config: model_api.ReaLModelConfig,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
hf_model_family: Optional[str] = None,
|
||||
):
|
||||
super().__init__()
|
||||
if dtype is None:
|
||||
|
@@ -173,6 +174,14 @@ class ReaLModel(nn.Module):
|
|||
)
|
||||
self.contiguous_param = None
|
||||
|
||||
self.hf_model_family = hf_model_family
|
||||
|
||||
def save_to_hf(self, tokenizer, save_dir):
|
||||
return getattr(self, f"to_{self.hf_model_family}")(tokenizer, save_dir)
|
||||
|
||||
def load_from_hf(self, load_dir):
|
||||
return getattr(self, f"from_{self.hf_model_family}")(load_dir)
|
||||
|
||||
@property
|
||||
def pre_process(self):
|
||||
# A workaround to make Megatron-LM backend happy.
|
||||
|
@@ -918,12 +927,7 @@ def make_real_model(
|
|||
model_path=model_path,
|
||||
is_critic=is_critic or init_critic_from_actor,
|
||||
)
|
||||
m = ReaLModel(mconfig, dtype=dtype, device=device)
|
||||
|
||||
# Since we load from `hf_model_family`, we should save to `hf_model_family`.
|
||||
# The following line creates a convenient function to save the model.
|
||||
setattr(ReaLModel, "save_to_hf", getattr(ReaLModel, f"to_{hf_model_family}"))
|
||||
setattr(ReaLModel, "load_from_hf", getattr(ReaLModel, f"from_{hf_model_family}"))
|
||||
m = ReaLModel(mconfig, dtype=dtype, device=device, hf_model_family=hf_model_family)
|
||||
|
||||
if not init_from_scratch:
|
||||
m._instantiation_hooks.append(
|
||||
|
|
|
@@ -166,6 +166,27 @@ def build_leave_one_indices(
|
|||
)
|
||||
|
||||
|
||||
@torch.compile
|
||||
def gather_logprobs(
|
||||
logits: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
):
|
||||
"""Gather log probs from logits and labels.
|
||||
|
||||
Args:
|
||||
logits (torch.FloatTensor): Shape [tot_seqlen]. The final value at the end of
|
||||
each sequence is not used.
|
||||
labels (torch.LongTensor): Labels or input_ids with shape [tot_seqlen].
|
||||
The first value at the beginning of each sequence has no corresponding log prob.
|
||||
|
||||
Returns:
|
||||
torch.FloatTensor: Log probability with shape [tot_seqlen - #seqs].
|
||||
"""
|
||||
log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
|
||||
log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
|
||||
return log_probs_labels
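# --- Illustrative sketch, not part of the patch above. ---
# What gather_logprobs computes: the per-token log-likelihood of the labels, i.e. the
# negative of the unreduced cross-entropy. The tensors below are invented for the example.
import torch
import torch.nn.functional as F

logits = torch.randn(5, 32)          # [tot_seqlen, vocab_size]
labels = torch.randint(0, 32, (5,))  # [tot_seqlen]

lp = F.log_softmax(logits, dim=-1).gather(-1, labels.unsqueeze(-1)).squeeze(-1)
assert torch.allclose(lp, -F.cross_entropy(logits, labels, reduction="none"))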
|
||||
|
||||
|
||||
def gather_packed_shifted_log_probs(
|
||||
logits: torch.FloatTensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
|
@@ -174,11 +195,11 @@ def gather_packed_shifted_log_probs(
|
|||
"""Gather log probs from packed input_ids and logits.
|
||||
|
||||
Args:
|
||||
logits_ (torch.FloatTensor): Shape [tot_seqlen]. The final value at the end of
|
||||
logits (torch.FloatTensor): Shape [tot_seqlen]. The final value at the end of
|
||||
each sequence is not used.
|
||||
cu_seqlens (torch.Tensor): Shape [#seqs + 1]. Indices marking the start
|
||||
and end of each sequences.
|
||||
labels_ (torch.LongTensor): Labels or input_ids with shape [tot_seqlen].
|
||||
and end of each sequence.
|
||||
labels (torch.LongTensor): Labels or input_ids with shape [tot_seqlen].
|
||||
The first value at the beginning of each sequence has no corresponding log prob.
|
||||
|
||||
Returns:
|
||||
|
@@ -202,8 +223,7 @@ def gather_packed_shifted_log_probs(
|
|||
# for i in range(cu_seqlens.shape[0] - 1)
|
||||
# ])
|
||||
# shift labels one step to the left and pad it to match the shape of logits
|
||||
log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
|
||||
log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
|
||||
log_probs_labels = gather_logprobs(logits, labels)
|
||||
log_probs_labels = log_probs_labels[leave_one_indices]
|
||||
assert log_probs_labels.shape[0] == logits_shape[0] - cu_seqlens.shape[0] + 1, (
|
||||
log_probs_labels.shape,
|
||||
|
|
|
@@ -127,9 +127,10 @@ class SlurmSchedulerClient(SchedulerClient):
|
|||
)
|
||||
wrap_cmd = "singularity exec "
|
||||
if cluster_spec.name == "na132":
|
||||
wrap_cmd += "--pid --no-home --writable-tmpfs "
|
||||
wrap_cmd += "--pid "
|
||||
if cluster_spec.gpu_type == "tesla":
|
||||
wrap_cmd += "--nv "
|
||||
wrap_cmd += "--no-home --writable-tmpfs "
|
||||
if len(launch_info.env_vars) > 0:
|
||||
wrap_cmd += f"{' '.join([f'--env {k}={v}' for k, v in launch_info.env_vars.items()])} "
|
||||
if len(launch_info.container_mounts) > 0:
|
||||
|
|
|
@@ -621,6 +621,9 @@ class RayController:
|
|||
REAL_MATH_METADATA_PATH=os.environ.get("REAL_MATH_METADATA_PATH", ""),
|
||||
REAL_CODE_METADATA_PATH=os.getenv("REAL_CODE_METADATA_PATH", ""),
|
||||
FUNCTIONCALL_SERVICE_DOMAIN=os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", ""),
|
||||
REAL_DUMP_TRACE=os.environ.get("REAL_DUMP_TRACE", "0"),
|
||||
REAL_RECORD_PERFORMANCE=os.environ.get("REAL_RECORD_PERFORMANCE", "0"),
|
||||
REAL_DUMP_MEMORY=os.environ.get("REAL_DUMP_MEMORY", "0"),
|
||||
)
|
||||
runtime_env = {
|
||||
"env_vars": env_vars,
|
||||
|
|
|
@@ -12,7 +12,7 @@ import torch.distributed as dist
|
|||
|
||||
from realhf import SequenceSample
|
||||
from realhf.api.core.config import ModelName, ModelShardID
|
||||
from realhf.base import constants
|
||||
from realhf.base import constants, logging
|
||||
from realhf.base.topology import ProcessTopology, new_or_get_group
|
||||
from realhf.impl.model.comm.global_comm import filter_match_mwids
|
||||
from realhf.system.redistributor import RedistribStep
|
||||
|
@@ -21,6 +21,8 @@ BCAST_GROUPS = {}
|
|||
GATHER_GROUPS = {}
|
||||
SCATTER_GROUPS = {}
|
||||
|
||||
logger = logging.getLogger("data_manager", "system")
|
||||
|
||||
|
||||
class DataManager:
|
||||
|
||||
|
@ -325,8 +327,21 @@ class DataManager:
|
|||
)
|
||||
|
||||
if dist.get_rank() == step.root:
|
||||
scatter_list = []
|
||||
for ids in step.ids:
|
||||
# Scatter destinations include all DP, TP, and PP ranks
|
||||
# and data is duplicated among TP/PP groups
|
||||
# We allocate new memory for DP ranks, but use the same pointer
|
||||
# for all TP and PP ranks to save memory.
|
||||
scatter_clusters = []
|
||||
for idx, ids in enumerate(step.ids):
|
||||
for _ids, idx_list in scatter_clusters:
|
||||
if set(ids) == set(_ids):
|
||||
idx_list.append(idx)
|
||||
break
|
||||
else:
|
||||
scatter_clusters.append((ids, [idx]))
|
||||
scatter_list = [None for _ in range(len(step.ids))]
|
||||
before_pad = []
|
||||
for ids, idx_list in scatter_clusters:
|
||||
for i in ids:
|
||||
self.storage[i].to_device(constants.current_device())
|
||||
samples = [self.storage[i] for i in ids]
|
||||
|
@ -337,11 +352,17 @@ class DataManager:
|
|||
for key in step.keys
|
||||
]
|
||||
)
|
||||
scatter_list.append(data)
|
||||
maxlen = max([x.shape[0] for x in scatter_list])
|
||||
scatter_list = [self._pad_data(x, maxlen) for x in scatter_list]
|
||||
if step.root not in step.dsts:
|
||||
before_pad.append(data)
|
||||
|
||||
maxlen = max([x.shape[0] for x in before_pad])
|
||||
after_pad = [self._pad_data(x, maxlen) for x in before_pad]
|
||||
for (ids, idx_list), data in zip(scatter_clusters, after_pad):
|
||||
for idx in idx_list:
|
||||
scatter_list[idx] = data
|
||||
|
||||
assert all([torch.is_tensor(t) for t in scatter_list])
|
||||
|
||||
if step.root not in step.dsts:
|
||||
idx = bisect.bisect(step.dsts, step.root)
|
||||
scatter_list.insert(idx, buf)
|
||||
else:
|
||||
|
|
|
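The comments in the hunk above describe deduplicating scatter buffers when several destination ranks request the same sample IDs. A standalone sketch of that clustering idea, on toy data and without torch.distributed:

# Illustrative only: cluster destinations that request identical id sets so the
# padded tensor is built once and shared by every rank in the cluster.
import torch

step_ids = [[0, 1], [1, 0], [2, 3]]                   # per-destination sample ids
storage = {i: torch.randn(i + 2) for i in range(4)}   # toy variable-length data

scatter_clusters = []
for idx, ids in enumerate(step_ids):
    for _ids, idx_list in scatter_clusters:
        if set(ids) == set(_ids):
            idx_list.append(idx)
            break
    else:
        scatter_clusters.append((ids, [idx]))

before_pad = [torch.cat([storage[i] for i in ids]) for ids, _ in scatter_clusters]
maxlen = max(x.shape[0] for x in before_pad)
after_pad = [torch.nn.functional.pad(x, (0, maxlen - x.shape[0])) for x in before_pad]

scatter_list = [None] * len(step_ids)
for (ids, idx_list), data in zip(scatter_clusters, after_pad):
    for idx in idx_list:
        scatter_list[idx] = data              # shared pointer, no extra copy

assert scatter_list[0] is scatter_list[1]     # destinations 0 and 1 share one buffer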
@@ -85,7 +85,11 @@ class MasterWorker(worker_base.Worker):
            and config.exp_ctrl.ckpt_freq_steps is None
            and config.exp_ctrl.ckpt_freq_secs is None
        ):
            self.__ckpt_ctl = self.__save_ctl
            self.__ckpt_ctl = timeutil.EpochStepTimeFreqCtl(
                freq_epoch=config.exp_ctrl.save_freq_epochs,
                freq_step=config.exp_ctrl.save_freq_steps,
                freq_sec=config.exp_ctrl.save_freq_secs,
            )
        else:
            self.__ckpt_ctl = timeutil.EpochStepTimeFreqCtl(
                freq_epoch=config.exp_ctrl.ckpt_freq_epochs,
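A simplified illustration (not the project's timeutil API) of why the checkpoint controller now gets its own EpochStepTimeFreqCtl built from the save frequencies instead of aliasing the save controller; StepFreqCtl below is a hypothetical stand-in that only counts steps.

# Illustrative only: two independent controllers fire on the intended schedule.
class StepFreqCtl:
    def __init__(self, freq_step):
        self.freq_step = freq_step
        self.count = 0

    def check(self):
        self.count += 1
        if self.count >= self.freq_step:
            self.count = 0
            return True
        return False

save_ctl = StepFreqCtl(freq_step=2)
ckpt_ctl = StepFreqCtl(freq_step=2)   # separate state, same frequency
for step in range(4):
    assert save_ctl.check() == ckpt_ctl.check()   # both fire on steps 2 and 4
# If ckpt_ctl were simply an alias of save_ctl, the shared counter would advance
# twice per step, so neither name would fire every second step as intended.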
@@ -24,7 +24,6 @@ import numpy as np
import pynvml
import tabulate
import torch
import torch.distributed
import torch.distributed as dist
import torch.utils.data

@@ -36,6 +35,8 @@ from realhf.base import (
    constants,
    gpu_utils,
    logging,
    name_resolve,
    names,
    network,
    recover,
    seeding,
@@ -375,7 +376,8 @@ class ModelWorker(worker_base.Worker):
                constants.MODEL_SAVE_ROOT,
                constants.experiment_name(),
                constants.trial_name(),
                f"dataset_indices_{self._dp_rank}.npy",
                "dataset_indices",
                f"{self._dp_rank}.npy",
            )
            if os.path.exists(dataset_indices_path):
                indices = np.load(dataset_indices_path).tolist()
@@ -456,8 +458,8 @@ class ModelWorker(worker_base.Worker):
        self.__request_cache = {}
        self.__ack_cache = {}

        self.__request_queue = queue.Queue(maxsize=8)
        self.__reply_queue = queue.Queue(maxsize=8)
        self.__request_queue = queue.Queue(maxsize=10240)
        self.__reply_queue = queue.Queue(maxsize=10240)
        self.__request_sample_size = dict()

        self.__compute_input_queues = {
@@ -469,6 +471,13 @@ class ModelWorker(worker_base.Worker):
            for model_name in self.__models.keys()
        }

        # By intention, must be smaller than -1.
        self._last_param_realloc_step = -100
        if self.__recover_run:
            self._last_param_realloc_step = (
                self.__recover_info.last_step_info.global_step
            )

    def __handle_one_rpc_hook(self, hook: str, hook_data: Any):
        ret = None
@@ -580,8 +589,10 @@ class ModelWorker(worker_base.Worker):
            constants.MODEL_SAVE_ROOT,
            constants.experiment_name(),
            constants.trial_name(),
            f"dataset_indices_{dp_rank}.npy",
            "dataset_indices",
            f"{dp_rank}.npy",
        )
        os.makedirs(os.path.dirname(dataset_indices_path), exist_ok=True)
        if hasattr(self.__dataset, "filter") and os.path.exists(
            eval_scores_path
        ):
@@ -625,9 +636,6 @@ class ModelWorker(worker_base.Worker):
            res = data_api.DataBatchMeta(
                dp_rank=dp_rank,
                meta_sample=meta_sample,
                is_final_batch=(
                    self.__dataset_batch_counter == len(self.__dataloader) - 1
                ),
            )
        elif request.handle_name == "spec":
            # Raw dataset without filtering.
@@ -737,6 +745,31 @@ class ModelWorker(worker_base.Worker):
                assert isinstance(res, dict), res
                res.update({f"eval_{k}": v for k, v in ret.items()})

        # update param realloc step after handling post hooks
        if request.handle_name == "train_step":
            self._last_param_realloc_step = max(self._last_param_realloc_step + 1, 1)
            realloc_dir = os.path.join(
                constants.PARAM_REALLOC_PATH,
                constants.experiment_name(),
                constants.trial_name(),
                model_name.role,
            )
            save_meta = dict(
                model_name=model_name,
                save_backend=False,
                save_dir=realloc_dir,
            )
            self.__save_model(save_meta)
            name = names.model_version(
                self.__experiment_name,
                self.__trial_name,
                model_name.role,
            )
            with constants.model_scope(model_name):
                dist.barrier(group=constants.parallelism_group())
                if constants.parallelism_rank() == 0:
                    name_resolve.add_subentry(name, str(self._last_param_realloc_step))

        self.__reply_queue.put_nowait((request, res))
        sample_count = data.bs if isinstance(data, data_api.SequenceSample) else 1
        self.__request_sample_size[request.request_id] = sample_count
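A toy sketch (not the project's API) of the publish step added above: after each train_step the version counter increases monotonically, starting from the recovered global step or 1, and is announced under a name-resolve key so other workers can pick up the new parameters. The key format and ToyNameResolve below are stand-ins.

# Illustrative only: the publish side of the parameter-version handshake.
class ToyNameResolve:
    def __init__(self):
        self.entries = {}

    def add_subentry(self, name, value):
        self.entries.setdefault(name, []).append(value)

name_resolve = ToyNameResolve()
last_param_realloc_step = -100          # "must be smaller than -1" sentinel

def on_train_step_done(experiment, trial, role):
    global last_param_realloc_step
    last_param_realloc_step = max(last_param_realloc_step + 1, 1)
    name = f"{experiment}/{trial}/{role}/model_version"   # placeholder key format
    name_resolve.add_subentry(name, str(last_param_realloc_step))

on_train_step_done("my-exp", "trial-0", "actor")
on_train_step_done("my-exp", "trial-0", "actor")
print(name_resolve.entries)   # {'my-exp/trial-0/actor/model_version': ['1', '2']}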
@@ -761,7 +794,7 @@ class ModelWorker(worker_base.Worker):
            or self.__enable_memory_dump
        ):
            torch.cuda.synchronize()
            torch.distributed.barrier(group=constants.parallelism_group())
            dist.barrier(group=constants.cpu_parallelism_group())
            # pfer can be a null context if enable_profiler is False
            pfer = get_pytorch_profiler(
                kernel_only=False, enabled=self.__enable_profiler
@@ -780,7 +813,7 @@ class ModelWorker(worker_base.Worker):
            or self.__enable_memory_dump
        ):
            pfer.__exit__(None, None, None)
            torch.distributed.barrier(group=constants.parallelism_group())
            dist.barrier(group=constants.cpu_parallelism_group())
            torch.cuda.synchronize()
        tok = time.perf_counter()
        rpc_time = tok - tik
@@ -800,6 +833,17 @@ class ModelWorker(worker_base.Worker):
            else:
                self.__performance_recorder["time"].append(rpc_time)

            with open(
                os.path.join(
                    self._get_setup_logdir("performance"),
                    f"rpc-mw{self.__worker_index}.txt",
                ),
                "a",
            ) as f:
                f.write(
                    f"rpc: {rpc.name} rank: {dist.get_rank()} time: {rpc_time}\n"
                )

        if self.__enable_profiler:
            if self._dp_rank == 0 and self._is_dp_head:
                blogger.info(
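Not part of the commit: a small helper that aggregates the per-RPC timing lines written above, assuming the exact "rpc: <name> rank: <rank> time: <seconds>" format shown in the hunk.

# Illustrative only: average per-RPC time from an appended rpc-mw<idx>.txt log.
from collections import defaultdict

def summarize(path):
    totals, counts = defaultdict(float), defaultdict(int)
    with open(path) as f:
        for line in f:
            parts = line.split()
            # ["rpc:", name, "rank:", rank, "time:", seconds]
            name, seconds = parts[1], float(parts[5])
            totals[name] += seconds
            counts[name] += 1
    return {name: totals[name] / counts[name] for name in totals}

# summarize("rpc-mw0.txt") might return {"train_step": 12.3, "generate": 45.6}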
@@ -902,7 +946,7 @@ class ModelWorker(worker_base.Worker):
                    eval_scores.update(scores)
                    res.metadata.pop("scores")
        dist.barrier(group=constants.parallelism_group())
        dist.barrier(group=constants.cpu_parallelism_group())
        if len(eval_scores) > 0 and self._dp_rank == 0 and self._is_dp_head:
            with open(
                eval_scores_path,
@@ -938,7 +982,7 @@ class ModelWorker(worker_base.Worker):
        self._clear_memory()
        if constants.use_cuda():
            torch.cuda.synchronize()
        dist.barrier(group=constants.parallelism_group())
        dist.barrier(group=constants.cpu_parallelism_group())
        return res

    @cuda_tmark("data_transfer", CUDATimeMarkType.comm)
@@ -980,7 +1024,7 @@ class ModelWorker(worker_base.Worker):
        with constants.model_scope(from_model_name):
            from_model_ranks = constants.parallelism_group_ranks()
        if not param_realloc_comm.is_trainable(from_model_name):
            if torch.distributed.get_rank() not in from_model_ranks:
            if dist.get_rank() not in from_model_ranks:
                return
            if not isinstance(self.__unwrapped_models[from_model_name], ReaLModel):
                # We can only release the memory of ReaLModel,
@@ -1006,7 +1050,7 @@ class ModelWorker(worker_base.Worker):
                save_dir=realloc_dir,
            )
            self.__save_model(save_meta)
            g = self.__param_realloc_info.param_realloc_model_group[
            g = self.__param_realloc_info.param_realloc_model_cpu_group[
                param_realloc_comm.ParamReallocModelPair(from_model_name, to_model_name)
            ]
            dist.barrier(group=g)
@@ -1018,7 +1062,7 @@ class ModelWorker(worker_base.Worker):
            self.__load_model(load_meta)
            # Remove the reallocated checkpoint.
            with constants.model_scope(to_model_name):
                dist.barrier(constants.parallelism_group())
                dist.barrier(constants.cpu_parallelism_group())
                if constants.parallelism_rank() == 0:
                    shutil.rmtree(realloc_dir, ignore_errors=True)
                    os.makedirs(realloc_dir, exist_ok=True)
@@ -1086,7 +1130,7 @@ class ModelWorker(worker_base.Worker):
                ).is_symlink():
                    os.unlink(save_root / fn)
            shutil.rmtree(save_dir, ignore_errors=True)
        dist.barrier(constants.parallelism_group())
        dist.barrier(constants.cpu_parallelism_group())
        self._interface.save(self._model, save_dir)
        # The `save` method of the interface may be empty.
        # We only save the backend state if the parameters have been indeed saved.
@@ -1153,11 +1197,12 @@ class ModelWorker(worker_base.Worker):
    @cuda_tmark("post_response", CUDATimeMarkType.misc)
    def maybe_post_responses(self):
        ready_to_post = []
        try:
            request, res = self.__reply_queue.get_nowait()
            ready_to_post.append((request, res))
        except queue.Empty:
            pass
        while True:
            try:
                request, res = self.__reply_queue.get_nowait()
                ready_to_post.append((request, res))
            except queue.Empty:
                break

        batch_size = sample_size = 0
        for request, res in ready_to_post:
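The change above replaces a single non-blocking get with a loop that drains the reply queue. The pattern in isolation (not from the commit):

# Illustrative only: drain everything currently buffered in a queue.Queue
# without blocking, as the new maybe_post_responses loop does.
import queue

def drain(q: queue.Queue) -> list:
    items = []
    while True:
        try:
            items.append(q.get_nowait())
        except queue.Empty:
            break
    return items

q = queue.Queue(maxsize=10240)
for i in range(3):
    q.put_nowait(i)
print(drain(q))   # [0, 1, 2]; an empty queue simply yields []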
@@ -3,7 +3,6 @@

import asyncio
import dataclasses
import itertools
import os
from collections import defaultdict
from typing import *

@@ -52,7 +52,7 @@ class Payload:
    syn_reply_id: uuid.UUID = None
    ack_reply_id: uuid.UUID = None

    no_syn: bool = False
    no_syn: bool = True

    send_time: float = None

@@ -164,7 +164,7 @@ class NameResolvingRequestClient:
        datas: List[Any] | None = None,
        payloads: List[Payload] | None = None,
        verbose: bool = True,
        no_syn: bool = False,
        no_syn: bool = True,
    ) -> List[uuid.UUID]:
        """Send requests of type `handle_type` to all `handlers` with
        corresponding `data`.

@@ -1,7 +1,7 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").

import asyncio
import dataclasses
import enum
import os
@@ -547,6 +547,18 @@ class Worker:
        """Implemented by sub-classes."""
        raise NotImplementedError()

    @property
    def running(self):
        return self.__running

    @property
    def exiting(self):
        return self.__exiting

    @property
    def is_configured(self):
        return self.__is_configured

    def configure(
        self,
        worker_info: system_api.WorkerInformation,
@@ -625,10 +637,11 @@ class Worker:
    def _exit_hook(self, exit_status: WorkerServerStatus):
        logger.warning(f"Exit with {exit_status}, hook not implemented, pass.")

    def exit(self):
    def exit(self, err: bool = False):
        self.logger.info("Exiting worker")
        self._exit_hook(WorkerServerStatus.COMPLETED)
        self.__set_status(WorkerServerStatus.COMPLETED)
        status = WorkerServerStatus.ERROR if err else WorkerServerStatus.COMPLETED
        self._exit_hook(status)
        self.__set_status(status)
        self.__exiting = True

    def interrupt(self):
@@ -681,8 +694,7 @@ class Worker:
            logger.error(f"Worker encountered error {e}", exc_info=True)
            if isinstance(e, WorkerException):
                raise e
            self.__set_status(WorkerServerStatus.ERROR)
            self._exit_hook(WorkerServerStatus.ERROR)
            self.exit(err=True)
            raise e

    def __host_key(self, key: str):
@@ -694,6 +706,32 @@ class Worker:
        name_resolve.watch_names(keys, call_back=self.exit)


class AsyncWorker(Worker):
    async def _poll_async(self) -> PollResult:
        raise NotImplementedError()

    async def run_async(self):
        self.logger.debug("Running worker now")
        try:
            while not self.exiting:
                await asyncio.sleep(0.0)
                self._server.handle_requests()
                if not self.running:
                    await asyncio.sleep(0.05)
                    continue
                if not self.is_configured:
                    raise RuntimeError("Worker is not configured")
                r = await self._poll_async()
        except KeyboardInterrupt:
            self.exit()
        except Exception as e:
            logger.error(f"Worker encountered error {e}", exc_info=True)
            if isinstance(e, WorkerException):
                raise e
            self.exit(err=True)
            raise e


class MappingThread:
    """Wrapped of a mapping thread.
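A usage sketch for the new AsyncWorker (not in the commit): a subclass only needs to implement _poll_async. The import path and the PollResult fields below are assumptions based on context, not taken from the diff.

# Illustrative only. Assumes worker_base exposes AsyncWorker and PollResult;
# PollResult's fields (sample_count, batch_count) are an assumption here.
import asyncio

from realhf.system.worker_base import AsyncWorker, PollResult

class NoopAsyncWorker(AsyncWorker):
    async def _poll_async(self) -> PollResult:
        # Yield to the event loop and report an empty poll.
        await asyncio.sleep(0.01)
        return PollResult(sample_count=0, batch_count=0)

# A controller would configure the worker and then drive asyncio.run(w.run_async()).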
@@ -17,11 +17,7 @@ import torch.distributed as dist
from realhf.api.core.config import ModelName, ModelShardID
from realhf.api.core.data_api import SequenceSample
from realhf.base import constants, testing, topology
from realhf.base.testing import (
    LocalMultiProcessTest,
    PipeDataModelParallelTopology,
    init_global_constants,
)
from realhf.base.testing import LocalMultiProcessTest, init_global_constants
from realhf.system.data_manager import DataManager
from realhf.system.redistributor import GlobalStorageTracker, RedistribPlanner

@@ -143,7 +139,7 @@ def _test_data_transfer(
):

    from_model_name = ModelName("data_transfer_test", 0)
    from_topo = PipeDataModelParallelTopology(
    from_topo = topology.PipeDataModelParallelTopology(
        num_pp=from_pp_dp_mp[0],
        num_mp=from_pp_dp_mp[-1],
        num_dp=from_pp_dp_mp[1],
@@ -152,7 +148,7 @@ def _test_data_transfer(
        gradient_accumulation_fusion=True,
    )
    to_model_name = ModelName("data_transfer_test", 1)
    to_topo = PipeDataModelParallelTopology(
    to_topo = topology.PipeDataModelParallelTopology(
        num_pp=to_pp_dp_mp[0],
        num_mp=to_pp_dp_mp[-1],
        num_dp=to_pp_dp_mp[1],

@@ -1,8 +1,8 @@
import functools
import gc
import json
import os
import pickle
import json
import time
from typing import *

@@ -26,9 +26,9 @@ from realhf.api.core.model_api import ReaLModelConfig
from realhf.base import constants, logging
from realhf.base.network import find_free_port
from realhf.base.testing import (
    init_global_constants,
    _DEFAULT_EXPR_NAME,
    _DEFAULT_TRIAL_NAME,
    init_global_constants,
)

logger = logging.getLogger("test async ref-rew")
@@ -0,0 +1,242 @@
# Copyright 2025 Ant Group Inc.

import functools
import multiprocessing as mp
import os
import uuid

import numpy as np
import torch
import torch.distributed as dist
import transformers

from realhf.api.core import data_api, model_api
from realhf.api.core.config import ModelName
from realhf.api.core.data_api import MicroBatchSpec
from realhf.api.core.model_api import ReaLModelConfig
from realhf.base import constants, logging
from realhf.base.testing import init_global_constants

logger = logging.getLogger("test sglang backend")


def check_sequences_consistency(
    batched_seq1: torch.LongTensor, batched_seq2: torch.LongTensor
):
    matched_tokens = 0
    matched_seqs = 0
    total_tokens = 0
    assert len(batched_seq1) == len(batched_seq2)
    for i in range(len(batched_seq1)):
        a = batched_seq1[i]
        b = batched_seq2[i]
        assert torch.is_tensor(a) and torch.is_tensor(b)
        assert a.dim() == 1 and b.dim() == 1, (a.shape, b.shape)
        gen_len = a.shape[0] if a.shape[0] < b.shape[0] else b.shape[0]
        b = b[:gen_len]
        a = a[:gen_len]
        for j in range(gen_len):
            if a[j] != b[j]:
                logger.info(f"Mismatch at sequence {i} position {j}")
                break
            matched_tokens += 1
        else:
            matched_seqs += 1
        total_tokens += gen_len
    logger.info(
        f"Matched {matched_seqs}/{len(batched_seq1)} "
        f"sequences and {matched_tokens}/{total_tokens} tokens"
    )
    return (
        matched_seqs,
        matched_tokens,
        float(matched_tokens) / total_tokens,
        float(matched_seqs) / len(batched_seq1),
    )


def test_fn(
    rank: int,
    world_size: int,
    path: str,
    model_family_name: str,
    dp: int,
    pp: int,
    tp: int,
):
    assert not torch.cuda.is_initialized()
    os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
    torch.cuda.set_device(0)
    assert world_size == (
        dp * pp * tp
    ), f"dp={dp}, pp={pp}, tp={tp}, world_size={world_size}"
    # Initialize distributed environment.
    dist.init_process_group(
        "nccl",
        rank=rank,
        world_size=world_size,
        init_method="tcp://localhost:7777",
    )
    torch.cuda.set_device(0)
    model_name = ModelName("default", 0)
    constants.set_experiment_trial_names("slang-test", str(uuid.uuid4()))
    init_global_constants(
        num_dp=dp,
        num_mp=tp,
        num_pp=pp,
        sequence_parallel=False,
        model_name=model_name,
        max_prompt_len=128,
    )

    from realhf.impl.model.nn.real_llm_api import ReaLModel, add_helper_functions

    mconfig: ReaLModelConfig = getattr(ReaLModel, f"config_from_{model_family_name}")(
        transformers.AutoConfig.from_pretrained(
            path,
            trust_remote_code=True,
            force_download=True,
        )
    )
    with constants.model_scope(model_name):
        module = ReaLModel(mconfig, dtype=torch.float16, device="cuda")
        module._instantiation_hooks.append(
            lambda: getattr(module, f"from_{model_family_name}")(
                load_dir=path, init_critic_from_actor=False
            )
        )
        add_helper_functions(module)
        module.instantiate()
        module.eval()
        tokenizer = data_api.load_hf_tokenizer(path)

        from realhf.impl.model.backend.sglang import SGLangGenerationBackend

        backend = SGLangGenerationBackend(model_path=path)
        model = model_api.Model(
            name=model_name,
            module=module,
            tokenizer=tokenizer,
            device=module.device,
            dtype=module.dtype,
        )
        ft_spec = model_api.FinetuneSpec(
            total_train_epochs=1,
            dataset_size=100,
            train_batch_size=1,
        )
        model = backend.initialize(model, ft_spec)

        gconfig = model_api.GenerationHyperparameters(
            n=1,
            max_new_tokens=32,
            min_new_tokens=0,
            greedy=True,
            top_p=1.0,
            top_k=int(1e8),
            temperature=1.0,
            use_cuda_graph=False,
        )

        bs = 8
        for i in range(1):
            seqlens = [torch.randint(5, 10, (1,)).cuda() for _ in range(bs)]

            for s in seqlens:
                dist.broadcast(s, src=0)
            seqlens = [int(s) for s in seqlens]

            token_ids = (
                torch.randint(0, mconfig.vocab_size, (sum(seqlens),)).long().cuda()
            )
            dist.broadcast(token_ids, src=0)

            max_seqlen = max(seqlens)
            cu_seqlens = torch.nn.functional.pad(
                torch.tensor(seqlens, device="cuda").cumsum(0),
                (1, 0),
            ).int()

            res = module.generate(
                tokenizer=tokenizer,
                packed_input_ids=token_ids,
                cu_seqlens=cu_seqlens,
                max_seqlen=max_seqlen,
                gconfig=gconfig,
            )
            gen_tokens1 = res.sequences
            logprobs1 = res.scores

            x = data_api.SequenceSample.from_default(
                seqlens=seqlens,
                ids=list(range(bs)),
                data=dict(packed_input_ids=token_ids),
            )
            gen_tokens2, logprobs2, _ = model.module.generate(
                input_=x,
                mb_spec=MicroBatchSpec(),
                tokenizer=tokenizer,
                gconfig=gconfig,
            )
            if constants.model_parallel_rank() == 0:
                # The outputs are Nones for tp_rank > 1 in SGLang
                _, _, token_match_percent, seq_match_percent = (
                    check_sequences_consistency(gen_tokens1, gen_tokens2)
                )
                assert token_match_percent > 0.8, token_match_percent
                assert seq_match_percent > 0.8, seq_match_percent

    print("success")

    # Clean up the distributed environment.
    dist.destroy_process_group()


def test_sglang_consistency(tp: int, dp: int, path: str, model_family_name: str):
    mp.set_start_method("spawn", force=True)
    world_size = dp * tp
    procs = [
        mp.Process(
            target=test_fn,
            args=(
                i,
                world_size,
            ),
            kwargs=dict(
                path=path,
                model_family_name=model_family_name,
                dp=dp,
                pp=1,
                tp=tp,
            ),
        )
        for i in range(world_size)
    ]
    try:
        for p in procs:
            p.start()
        for p in procs:
            p.join()
    except KeyboardInterrupt:
        [p.terminate() for p in procs]
        [p.join() for p in procs]


if __name__ == "__main__":
    path = "/storage/openpsi/models/Qwen__Qwen2-1.5B-Instruct/"
    model_family_name = "qwen2"
    # test_fn(
    #     rank=0,
    #     world_size=1,
    #     path=path,
    #     model_family_name=model_family_name,
    #     dp=1,
    #     pp=1,
    #     tp=1,
    # )
    test_sglang_consistency(
        tp=2,
        dp=2,
        path=path,
        model_family_name=model_family_name,
    )
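As a quick sanity check (not part of the commit), check_sequences_consistency from the new test file can be exercised on toy tensors:

# Illustrative only: exercising check_sequences_consistency on toy data.
import torch

seqs_a = [torch.tensor([1, 2, 3, 4]), torch.tensor([5, 6, 7])]
seqs_b = [torch.tensor([1, 2, 9, 4]), torch.tensor([5, 6, 7, 8])]
matched_seqs, matched_tokens, token_pct, seq_pct = check_sequences_consistency(
    seqs_a, seqs_b
)
# Sequence 0 diverges at position 2; sequence 1 matches on its common prefix:
# matched_seqs == 1, matched_tokens == 5, token_pct == 5/7, seq_pct == 0.5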