# AReaL/realhf/search_engine/layers.py
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import os
import pickle
import time
from collections import defaultdict
from typing import List, Optional, Union

import pandas as pd
import torch
import torch.distributed as dist
import transformers

import realhf.api.core.model_api as model_api
import realhf.api.core.system_api as config_package
import realhf.base.constants as constants
import realhf.base.logging as logging
from realhf.api.core.model_api import ReaLModelConfig
from realhf.impl.model.utils.padding import unpad_input

logger = logging.getLogger("profile layers", "system")


def make_layers(config: ReaLModelConfig, dtype, device):
    # Build a minimal three-layer stack (embedding, a single transformer block,
    # and the output head) so that per-layer costs can be profiled in isolation.
    from realhf.impl.model.nn.real_llm_base import (
        OutputHead,
        ReaLModelBlock,
        VocabPositionEmbedding,
    )

    embedding_layer = VocabPositionEmbedding(
        config,
        dtype=dtype,
        device=device,
    )
    real_model_blocks = [
        ReaLModelBlock(
            config,
            layer_index=i,
            output_layernorm=(i == 1),
            dtype=dtype,
            device=device,
        )
        for i in range(1)
    ]
    head = OutputHead(
        config.hidden_dim,
        1 if config.is_critic else config.vocab_size,
        bias=False,
        device=device,
        dtype=dtype,
    )
    layer_names = ["embedding_layer", "block_0", "head"]
    return [embedding_layer] + real_model_blocks + [head], layer_names


class ProfileLayers:

    def __init__(
        self,
        model_name: str,
        config: ReaLModelConfig,
        tokenizer: transformers.PreTrainedTokenizerFast = None,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[str, torch.device]] = None,
    ):
        self.model_name = model_name
        self.config = config
        self.backend_config = config_package.ModelBackend(
            type_="deepspeed",
            args=dict(
                optimizer_name="adam",
                optimizer_config=dict(lr=1e-5, weight_decay=0.0, betas=(0.9, 0.95)),
                warmup_steps_proportion=0.0,
                min_lr_ratio=0.0,
                zero_stage=1,
                bf16=False,
            ),
        )
        self.dtype = dtype
        self.device = device
        self.layers, self.layer_names = make_layers(config, dtype, device)

        self.hidden_dim = config.hidden_dim
        self.head_dim = config.head_dim
        self.max_new_tokens = 128  # only useful in kv cache memory alloc
        self.min_new_tokens = 128

        self.stats = defaultdict(list)
        self.num_layers = len(self.layers)
        self.layers = [
            model_api.Model(name, layer, tokenizer, device=device, dtype=dtype)
            for layer, name in zip(self.layers, self.layer_names)
        ]
        self.backend = model_api.make_backend(self.backend_config)
        ft_spec = model_api.FinetuneSpec(10, 100, 10)
        self.layers = [self.backend.initialize(layer, ft_spec) for layer in self.layers]
        self.stats = defaultdict(list)

    def reset_stats(self):
        self.stats = defaultdict(list)

    def insert_data_point(self, layer_name, name, bs, seq_len, time_ns):
        self.stats["layer_name"].append(layer_name)
        self.stats["op_name"].append(name)
        self.stats["bs"].append(bs)
        self.stats["seq_len"].append(seq_len)
        self.stats["time_ns"].append(time_ns)

    @torch.no_grad()
    def fwd_gen(self, bs, seq_len):
        from realhf.impl.model.nn.real_llm_base import PipeCacheData, PipeTransferData

        input_ids = torch.randint(
            0,
            self.config.vocab_size,
            (bs, seq_len),
            dtype=torch.long,
            device=self.device,
        )
        attention_mask = torch.ones_like(input_ids, device=self.device)

        # fwd_gen_0: prefill over the packed prompt tokens
        packed_input_ids, _, cu_seqlens, max_seqlen = unpad_input(
            input_ids, attention_mask
        )
        cu_seqlens = cu_seqlens.to(device=self.device)
        packed_input_ids = packed_input_ids.to(device=self.device)
        x = PipeTransferData(
            cu_seqlens=cu_seqlens,
            max_seqlen=int(max_seqlen),
            store_kv_cache=True,
        )
        ys = [PipeCacheData() for _ in range(self.num_layers)]
        ys[0].packed_input_ids = packed_input_ids
        for layer_name, layer, y in zip(self.layer_names, self.layers, ys):
            st = time.monotonic_ns()
            x: PipeTransferData = layer.module(x, y)
            x.pp_input = x.pp_output
            torch.cuda.synchronize()
            self.insert_data_point(
                layer_name, "fwd_gen_0", bs, seq_len, time.monotonic_ns() - st
            )

        prompt_logits = x.pp_output
        logits = prompt_logits[cu_seqlens[1:] - 1]
        input_lens = cu_seqlens[1:] - cu_seqlens[:-1]
        cache_seqlens = input_lens.clone().to(dtype=torch.int32)
        layer_indices = range(len(ys))
        for y, layer_idx in zip(ys[1:-1], layer_indices[1:-1]):
            assert (
                y.k_cache is not None
                and y.v_cache is not None
                and y.cache_seqlens is not None
            )
            kvcache_seqlen = max(
                max_seqlen + self.max_new_tokens,
                self.hidden_dim // self.head_dim + 10,
            )
            # fix of a flash attention bug
            k_cache = torch.zeros(
                (bs, kvcache_seqlen, *y.k_cache.shape[1:]),
                dtype=y.k_cache.dtype,
                device=self.device,
            )
            v_cache = torch.zeros_like(k_cache)
            indices = (
                torch.arange(
                    kvcache_seqlen,
                    device=constants.current_device(),
                    dtype=torch.long,
                )[None, :]
                < input_lens[:, None]
            )
            k_cache[indices] = y.k_cache
            v_cache[indices] = y.v_cache
            y.k_cache = k_cache
            y.v_cache = v_cache
            y.cache_seqlens = cache_seqlens

        x = PipeTransferData(store_kv_cache=True)
        ys[0].cache_seqlens = cache_seqlens

        # fwd_gen_1: a single decoding step with the populated KV cache
        new_tokens = torch.randint(
            0,
            self.config.vocab_size,
            (bs,),
            dtype=torch.long,
            device=self.device,
        )
        ys[0].packed_input_ids = new_tokens
        ys[0].packed_position_ids = None
        x.cu_seqlens = torch.arange(bs + 1, dtype=torch.int32, device=self.device)
        x.max_seqlen = 1
        for layer_name, layer, y in zip(self.layer_names, self.layers, ys):
            st = time.monotonic_ns()
            x = layer.module(x, y)
            x.pp_input = x.pp_output
            torch.cuda.synchronize()
            self.insert_data_point(
                layer_name, "fwd_gen_1", bs, seq_len, time.monotonic_ns() - st
            )

    def fwd_bwd_opt(self, bs, seq_len):
        from realhf.impl.model.nn.real_llm_base import PipeCacheData, PipeTransferData

        input_ids = torch.randint(
            0,
            self.config.vocab_size,
            (bs, seq_len),
            dtype=torch.long,
            device=self.device,
        )
        attention_mask = torch.ones_like(input_ids, device=self.device)
        packed_input_ids, _, cu_seqlens, max_seqlen = unpad_input(
            input_ids, attention_mask
        )
        cu_seqlens = cu_seqlens.to(device=self.device)
        packed_input_ids = packed_input_ids.to(device=self.device)
        x = PipeTransferData(
            cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, store_kv_cache=False
        )
        ys = [PipeCacheData() for _ in range(self.num_layers)]
        ys[0].packed_input_ids = packed_input_ids
        for layer_name, layer, y in zip(self.layer_names, self.layers, ys):
            # fwd
            st = time.monotonic_ns()
            x: PipeTransferData = layer.module(x, y)
            torch.cuda.synchronize()
            self.insert_data_point(
                layer_name, "fwd", bs, seq_len, time.monotonic_ns() - st
            )
            # bwd
            r = torch.rand(
                *x.pp_output.shape,
                device=x.pp_output.device,
                dtype=x.pp_output.dtype,
            )
            loss = torch.max(x.pp_output * r)
            st = time.monotonic_ns()
            layer.module.backward(loss)
            torch.cuda.synchronize()
            self.insert_data_point(
                layer_name, "bwd", bs, seq_len, time.monotonic_ns() - st
            )
            # opt
            st = time.monotonic_ns()
            layer.module.step()
            torch.cuda.synchronize()
            self.insert_data_point(
                layer_name, "opt", bs, seq_len, time.monotonic_ns() - st
            )
            x.pp_input = x.pp_output.clone().detach()
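
    # Note: layer.module.backward()/step() follow the engine-style API of the
    # "deepspeed" backend configured in __init__ (presumably an engine wrapping
    # each layer), which is why the loss and optimizer are not driven directly
    # here.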

    def make_dataframe_and_print(self):
        df = pd.DataFrame(self.stats)
        logger.info(f"Current Stats: \n{df}")

    def dump_stats(self, world_size):
        rank = dist.get_rank()
        # dump full stats
        dump_dir = os.path.join(
            constants.PROFILER_CACHE_PATH,
            "layer_stats",
        )
        dump_path = os.path.join(
            dump_dir, self.model_name, f"layer-stats_{world_size}_{rank}.pkl"
        )
        os.makedirs(os.path.dirname(dump_path), exist_ok=True)
        with open(dump_path, "wb") as f:
            df = pd.DataFrame(self.stats)
            pickle.dump(df, f)


def make_profile_layers(
    device: torch.device,
    model_path: str,
    model_name: str,
    dtype: Optional[str] = None,
    hf_model_type: str = "llama",
):
    from realhf.impl.model.nn.real_llm_api import ReaLModel

    # Map the string dtype flag to a torch dtype; fp16 is the default.
    if dtype == "fp16" or dtype is None:
        dtype = torch.float16
    elif dtype == "bf16":
        dtype = torch.bfloat16
    elif dtype == "fp32":
        dtype = torch.float32
    else:
        raise NotImplementedError(f"Unsupported dtype {dtype}")

    tokenizer = None
    config: ReaLModelConfig = getattr(ReaLModel, f"config_from_{hf_model_type}")(
        model_path=model_path,
    )
    if tokenizer is None:
        tokenizer = model_api.load_hf_tokenizer(model_path)
    profile_layers = ProfileLayers(
        model_name, config, tokenizer=tokenizer, dtype=dtype, device=device
    )
    return profile_layers
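

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It assumes a
# CUDA device, an initialized torch.distributed process group (dump_stats()
# calls dist.get_rank()), and a local HuggingFace Llama checkpoint; the
# checkpoint path and the (bs, seq_len) grid below are placeholders.
#
#     profile_layers = make_profile_layers(
#         device=torch.device("cuda:0"),
#         model_path="/path/to/llama/checkpoint",  # hypothetical path
#         model_name="llama-7b",
#         dtype="fp16",
#         hf_model_type="llama",
#     )
#     for bs, seq_len in [(1, 128), (4, 512)]:
#         profile_layers.fwd_gen(bs, seq_len)
#         profile_layers.fwd_bwd_opt(bs, seq_len)
#     profile_layers.make_dataframe_and_print()
#     profile_layers.dump_stats(world_size=1)
# ---------------------------------------------------------------------------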