# mirror of https://github.com/inclusionAI/AReaL
import functools
|
|
import gc
|
|
import os
|
|
import pickle
|
|
import json
|
|
import time
|
|
from typing import *
|
|
|
|
import numpy as np
|
|
import pynvml
|
|
import torch
|
|
import torch.distributed as dist
|
|
import torch.multiprocessing as mp
|
|
import transformers
|
|
from torch.cuda import is_initialized
|
|
|
|
# Repository root: three directories above this file. It is prepended to
# sys.path below so `realhf` imports resolve when running this file directly
# (i.e. without installing the package).
root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
print(root_dir)
import sys

sys.path.insert(0, root_dir)
|
|
|
|
from realhf.api.core import data_api, dfg, model_api
|
|
from realhf.api.core.config import ModelName
|
|
from realhf.api.core.model_api import ReaLModelConfig
|
|
from realhf.base import constants, logging
|
|
from realhf.base.network import find_free_port
|
|
from realhf.base.testing import (
|
|
init_global_constants,
|
|
_DEFAULT_EXPR_NAME,
|
|
_DEFAULT_TRIAL_NAME,
|
|
)
|
|
|
|
logger = logging.getLogger("test async ref-rew")
# Metadata file consumed by loadJson(); maps sample id -> info used by the
# reward interface. Hard-coded cluster path — adjust for other environments.
os.environ["REAL_MATH_METADATA_PATH"] = "/storage/datasets/id2info.json"
|
|
|
|
|
|
def loadJson():
    """Load the math metadata file referenced by ``REAL_MATH_METADATA_PATH``.

    Two on-disk formats are supported:
      * ``*.jsonl`` -- one JSON object per line; returned as a list.
      * anything else -- a single JSON document (``_mock_input`` expects a
        dict mapping sample id -> info).

    Returns:
        The parsed samples (list for ``.jsonl``, otherwise whatever the
        JSON document contains).

    Raises:
        KeyError: if ``REAL_MATH_METADATA_PATH`` is not set.
    """
    data_path = os.environ["REAL_MATH_METADATA_PATH"]
    # Explicit encoding so behavior does not depend on the locale.
    with open(data_path, "r", encoding="utf-8") as f:
        if data_path.endswith(".jsonl"):
            # Iterate the file object directly instead of readlines():
            # same result without materializing the raw lines first.
            samples = [json.loads(line) for line in f]
        else:
            samples = json.load(f)

    return samples
|
|
|
|
|
|
def _mock_input(batch_size: int, seq_len):
    """Build a deterministic dummy SequenceSample for interface testing.

    Token ids are drawn from a tiny vocabulary with a fixed seed. Sample
    ids come from the math metadata file so the reward interface can look
    them up (assumes loadJson() returns a dict keyed by sample id — the
    non-jsonl format).
    """
    torch.manual_seed(1)  # fixed seed -> reproducible batch
    vocab = 100
    seqs = torch.randint(0, vocab, (batch_size, seq_len), dtype=torch.long)
    n_seqs = seqs.shape[0]

    metadata_ids = list(loadJson().keys())
    # Indexing (rather than slicing) intentionally fails loudly if the
    # metadata has fewer entries than the batch.
    batch_ids = [metadata_ids[i] for i in range(n_seqs)]

    flat_tokens = seqs.view(-1)
    flat_prompts = seqs[:, :seq_len].contiguous().view(-1)
    return data_api.SequenceSample.from_default(
        seqlens=[seq_len] * n_seqs,
        ids=batch_ids,
        data=dict(
            packed_input_ids=flat_tokens,
            packed_prompts=flat_prompts,
        ),
    )
|
|
|
|
|
|
def funcion_call(
    rpc_name: str,
    rank: int,
    world_size: int,
    model_path: str,
    model_family_name: str,
    dp: int,
    pp: int,
    tp: int,
    interface_type: dfg.ModelInterfaceType,
    interface_impl: dfg.ModelInterfaceAbstraction,
    batch_size: int,
    prompt_len: int,
    input_: data_api.SequenceSample | None,
    port: int,
):
    """Instantiate a ReaLModel on this rank and execute one interface RPC.

    Sets up a (localhost) NCCL process group, builds the model / backend /
    interface stack for ``rpc_name``, runs a single generate / train /
    inference step, and returns the interface's result — a
    ``SequenceSample`` (moved to CPU on the relevant rank) or whatever
    stats object the interface produces.

    NOTE(review): the function name has a typo ("funcion"); kept because
    callers in this file use it.

    Args:
        rpc_name: One of the RPC names handled by ``run_function_call``;
            only used here to decide ``is_critic`` and for logging.
        rank: This process's rank; also selects the visible GPU.
        world_size: Must equal ``dp * pp * tp``.
        model_path: HuggingFace checkpoint directory.
        model_family_name: Suffix used to resolve ``config_from_*`` /
            ``from_*`` / ``to_*`` methods on ``ReaLModel``.
        dp / pp / tp: Data / pipeline / tensor parallel degrees.
        interface_type: Which interface method to invoke.
        interface_impl: Abstraction describing the interface to build.
        batch_size: Number of sequences; must be divisible by ``dp``.
        prompt_len: Prompt length; also the max prompt length registered
            with the global constants.
        input_: Pre-built input batch, or None to use ``_mock_input``.
        port: Free TCP port for the process-group rendezvous.
    """

    # assert not torch.cuda.is_initialized()
    # Pin this process to a single GPU *before* CUDA initialization, then
    # address it as device 0 inside the restricted visibility set.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
    torch.cuda.set_device(0)
    assert world_size == (
        dp * pp * tp
    ), f"dp={dp}, pp={pp}, tp={tp}, world_size={world_size}"
    assert batch_size % dp == 0, (batch_size, dp)

    # Initialize distributed environment.
    model_name = ModelName("default", 0)
    if not dist.is_initialized():
        logger.info("Setting up distributed environment...")
        dist.init_process_group(
            "nccl",
            rank=rank,
            world_size=world_size,
            init_method=f"tcp://localhost:{port}",
        )
        logger.info("Initialized distributed environment.")
    # Register parallelism topology; sequence parallelism only makes sense
    # for training-step RPCs here.
    init_global_constants(
        num_dp=dp,
        num_mp=tp,
        num_pp=pp,
        sequence_parallel=interface_type == dfg.ModelInterfaceType.TRAIN_STEP,
        model_name=model_name,
        max_prompt_len=prompt_len,
    )
    torch.cuda.set_device(0)

    # NOTE: import here to avoid CUDA re-initialization

    from realhf.impl.model.nn.real_llm_api import ReaLModel, add_helper_functions

    # Call a method like `config_from_llama` to get the config.
    mconfig: ReaLModelConfig = getattr(ReaLModel, f"config_from_{model_family_name}")(
        transformers.AutoConfig.from_pretrained(model_path)
    )
    # Critic/reward RPCs use a value head instead of the LM head.
    is_critic = rpc_name in ["critic_inf", "critic_train", "rew_inf"]
    mconfig.is_critic = is_critic
    with constants.model_scope(model_name):
        # Construct the model.
        logger.info(f"Loading model from {model_path}...")
        module = ReaLModel(mconfig, dtype=torch.bfloat16, device="cuda")
        # Wire the family-specific HF conversion methods under generic names.
        setattr(ReaLModel, "save_to_hf", getattr(ReaLModel, f"to_{model_family_name}"))
        setattr(
            ReaLModel, "load_from_hf", getattr(ReaLModel, f"from_{model_family_name}")
        )
        # Defer loading HF weights until instantiate() runs the hooks.
        module._instantiation_hooks.append(
            lambda: getattr(module, f"from_{model_family_name}")(
                load_dir=model_path,
                init_critic_from_actor=is_critic,
            )
        )
        add_helper_functions(module)
        module.instantiate()
        module.eval()

        tokenizer = data_api.load_hf_tokenizer(model_path)

        model = model_api.Model(
            name=model_name,
            module=module,
            tokenizer=tokenizer,
            device=module.device,
            dtype=module.dtype,
        )
        # Training needs the Megatron backend; everything else runs on the
        # lighter pipelined inference backend. Imports stay lazy (see NOTE
        # above about CUDA re-initialization).
        if interface_type == dfg.ModelInterfaceType.TRAIN_STEP:
            from realhf.impl.model.backend.megatron import MegatronTrainBackend

            backend = MegatronTrainBackend()
        else:
            from realhf.impl.model.backend.inference import PipelineInferenceBackend

            backend = PipelineInferenceBackend()

        logger.info("Running backend initialization...")
        ft_spec = model_api.FinetuneSpec(
            total_train_epochs=1,
            dataset_size=128,
            train_batch_size=128,
        )
        model = backend.initialize(model, ft_spec)

        interface = model_api.make_interface(interface_impl)

        if input_ is None:
            input_ = _mock_input(batch_size, prompt_len)

        input_ = input_.cuda()

        mb_spec = model_api.MicroBatchSpec()

        logger.info("Running interface computation...")
        start = time.perf_counter_ns()
        if interface_type == dfg.ModelInterfaceType.GENERATE:
            res = interface.generate(model, input_, mb_spec)
        elif interface_type == dfg.ModelInterfaceType.TRAIN_STEP:
            res = interface.train_step(model, input_)
        else:
            res = interface.inference(model, input_, mb_spec)

        # Only the rank holding the final output (tp rank 0, last pipeline
        # stage) moves the result to CPU; other ranks return it as-is.
        if constants.model_parallel_rank() == 0 and constants.is_last_pipe_stage():
            if isinstance(res, data_api.SequenceSample):
                res = res.cpu()

        # (sic) "comsumed" — wall-clock time of the interface call, in ns.
        comsumed = time.perf_counter_ns() - start
        logger.info(f"{rpc_name} Computation done. {comsumed} ns")
        return res
|
|
|
|
|
|
def run_function_call(
    rpc_name: str,
    model_path: str,
    model_family_name: str,
    batch_size: int,
    prompt_len: int,
    gen_len: int,
    input_: data_api.SequenceSample | None,
) -> data_api.SequenceSample | None:
    """Configure and execute a single model RPC on one GPU (dp=pp=tp=1).

    Builds the interface abstraction matching ``rpc_name``, runs it via
    ``funcion_call``, and frees GPU memory afterwards. Returns the
    resulting ``SequenceSample`` when the interface produces one;
    otherwise logs the returned stats and returns None.
    """
    ref_rw_interface = dfg.ModelInterfaceAbstraction(
        "ref_rw",
        args=dict(
            generation_config=dict(
                max_new_tokens=gen_len, min_new_tokens=gen_len, greedy=True
            ),
            rew_inf_args=dict(
                tokenizer_path=model_path,
            ),
        ),
    )

    ppo_actor_interface = dfg.ModelInterfaceAbstraction(
        "ppo_actor",
        args=dict(
            generation_config=dict(
                max_new_tokens=gen_len, min_new_tokens=gen_len, greedy=True
            ),
            rew_inf_args=dict(
                tokenizer_path=model_path,
            ),
        ),
    )
    ppo_critic_interface = dfg.ModelInterfaceAbstraction("ppo_critic")
    rw_interface = dfg.ModelInterfaceAbstraction(
        "paired_rw",
    )

    # Each supported RPC maps to (interface type, interface implementation).
    dispatch = {
        "actor_gen": (dfg.ModelInterfaceType.GENERATE, ppo_actor_interface),
        "actor_train": (dfg.ModelInterfaceType.TRAIN_STEP, ppo_actor_interface),
        "critic_inf": (dfg.ModelInterfaceType.INFERENCE, ppo_critic_interface),
        "critic_train": (dfg.ModelInterfaceType.TRAIN_STEP, ppo_critic_interface),
        "ref_inf": (dfg.ModelInterfaceType.INFERENCE, ppo_actor_interface),
        "ref_rw": (dfg.ModelInterfaceType.INFERENCE, ref_rw_interface),
        "rew_inf": (dfg.ModelInterfaceType.INFERENCE, rw_interface),
    }
    assert rpc_name in dispatch
    interface_type, interface_impl = dispatch[rpc_name]

    logger.info(f"Running RPC {rpc_name}...")

    port = find_free_port()
    res = funcion_call(
        rpc_name=rpc_name,
        rank=0,
        world_size=1,
        model_path=model_path,
        model_family_name=model_family_name,
        dp=1,
        pp=1,
        tp=1,
        interface_type=interface_type,
        interface_impl=interface_impl,
        batch_size=batch_size,
        prompt_len=prompt_len,
        input_=input_,
        port=port,
    )

    # Aggressively release GPU memory between consecutive RPCs.
    gc.collect()
    torch.cuda.empty_cache()
    gc.collect()

    if not isinstance(res, data_api.SequenceSample):
        logger.info(f"RPC {rpc_name} stats: {res}")
        return None
    return res
|
|
|
|
|
|
def main():
    """Smoke-test the ``ref_rw`` RPC twice on a single GPU.

    Uses hard-coded cluster paths and batch geometry; running twice
    exercises repeated backend setup / GPU-memory cleanup.
    """
    mp.set_start_method("spawn", force=True)

    # Hard-coded test configuration.
    model_family_name = "qwen2"
    batch_size = 16
    prompt_len = 128
    gen_len = 4096
    model_path = "/storage/models/DeepSeek-R1-Distill-Qwen-1.5B"

    constants.set_experiment_trial_names(_DEFAULT_EXPR_NAME, _DEFAULT_TRIAL_NAME)

    # Loop variable and the RPC's return value are intentionally unused:
    # this is purely a does-it-run check.
    for _ in range(2):
        run_function_call(
            "ref_rw",
            model_family_name=model_family_name,
            model_path=model_path,
            batch_size=batch_size,
            prompt_len=prompt_len,
            gen_len=gen_len,
            input_=None,
        )
|
|
|
|
|
|
# Entry point: run the smoke test only when executed as a script.
if __name__ == "__main__":
    main()
|