PullRequest: 56 Support the CUDA 12.8 image with Megatron v0.11.0 and SGLang 0.4.4

Merge branch fw/megatron-v0.11.0 of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/56

Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* update trial
* add moe test script
* .
* .
* .
* .
* .
* .
* .
* .
* .
* remove gae2d
* .
This commit is contained in:
博惟 2025-03-25 16:02:10 +08:00
parent 46b7d3d32b
commit 6ccbb01ca8
8 changed files with 166 additions and 304 deletions

View File

@ -147,7 +147,6 @@ class vLLMConfig:
class SGLangConfig:
disable_cuda_graph: bool = False
disable_radix_cache: bool = False
disable_jump_forward: bool = False
disable_cuda_graph_padding: bool = False
enable_nccl_nvls: bool = False
disable_outlines_disk_cache: bool = False
@ -169,7 +168,6 @@ class SGLangConfig:
num_continuous_decode_steps: int = 1
enable_memory_saver: bool = False
allow_auto_truncate: bool = False
return_hidden_states: bool = False
# NOTE: to avoid the illegal memory access error
attention_backend: Optional[str] = "triton"
sampling_backend: Optional[str] = None
@ -253,6 +251,13 @@ class MegatronConfig:
)
# Don't use MegatronOptimizerConfig here because OmegaConf
# does not recognize the annotation "torch.dtype"
overlap_param_gather_with_optimizer_step: bool = False
use_precision_aware_optimizer: bool = False
main_grads_dtype: str = "float32"
main_params_dtype: str = "float32"
exp_avg_dtype: str = "float32"
exp_avg_sq_dtype: str = "float32"
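These dtype fields are plain strings (rather than torch.dtype annotations, which OmegaConf cannot parse, per the note above). A minimal sketch of how such strings resolve to torch dtypes, mirroring the backend code further down in this diff; illustrative only:

import torch

# The string values must be valid torch dtype attribute names; they are
# resolved with getattr, exactly as the Megatron backend does later in this PR.
main_grads_dtype = getattr(torch, "float32")  # -> torch.float32
exp_avg_dtype = getattr(torch, "bfloat16")    # -> torch.bfloat16
assert main_grads_dtype is torch.float32 and exp_avg_dtype is torch.bfloat16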
@dataclasses.dataclass

View File

@ -0,0 +1,49 @@
from importlib.metadata import PackageNotFoundError
from importlib.metadata import version as get_version
from packaging.version import Version
def is_available(pkg_name):
    try:
        return bool(get_version(pkg_name))
    except PackageNotFoundError:
        return False
def compare_versions(version1: str, version2: str) -> int:
"""
Compare two version strings.
:param version1: First version string.
:param version2: Second version string.
:return: -1 if version1 < version2, 0 if version1 == version2, 1 if version1 > version2.
"""
v1 = Version(version1)
v2 = Version(version2)
if v1 < v2:
return -1
elif v1 == v2:
return 0
else:
return 1
def is_version_greater_or_equal(package_name: str, target_version: str) -> bool:
"""
Check if the installed version of a package is greater than or equal to the target version.
:param package_name: Name of the package.
:param target_version: Target version to compare against.
:return: True if the installed version is greater than or equal to the target version, False otherwise.
"""
installed_version = get_version(package_name)
return compare_versions(installed_version, target_version) >= 0
def is_version_less(package_name: str, target_version: str) -> bool:
"""
Check if the installed version of a package is less than the target version.
:param package_name: Name of the package.
:param target_version: Target version to compare against.
:return: True if the installed version is less than the target version, False otherwise.
"""
installed_version = get_version(package_name)
return compare_versions(installed_version, target_version) < 0
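Usage note (not part of the committed file): these helpers are how the rest of this PR gates version-specific code paths, for example:

from realhf.base import pkg_version

if pkg_version.is_available("megatron.core"):
    if pkg_version.is_version_greater_or_equal("megatron.core", "0.11.0"):
        ...  # take the megatron.core >= 0.11.0 code path
    elif pkg_version.is_version_less("megatron.core", "0.7.0"):
        ...  # fall back to the legacy < 0.7.0 code path

# compare_versions also works directly on raw version strings.
assert pkg_version.compare_versions("0.11.0", "0.4.4") == 1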

View File

@ -21,10 +21,16 @@ logger = logging.getLogger("model init")
# Import all model implementations.
_p = re.compile(r"^(?!.*__init__).*\.py$")
_filepath = os.path.dirname(__file__)
import_module(os.path.join(_filepath, "backend"), _p)
import_module(os.path.join(_filepath, "interface"), _p)
import_module(os.path.join(_filepath, "nn"), _p)
# NOTE: skip importing vLLM for now to avoid an
# "invalid device context" issue for the 25.01 image
import realhf.impl.model.backend.inference
import realhf.impl.model.backend.megatron
import realhf.impl.model.backend.mock_train
import realhf.impl.model.backend.sglang
# Set PyTorch JIT options, following Megatron-LM.
if torch.cuda.is_available():
torch._C._jit_set_profiling_executor(True)

View File

@ -23,7 +23,7 @@ from realhf.api.quickstart.model import (
MegatronConfig,
OptimizerConfig,
)
from realhf.base import constants, logging
from realhf.base import constants, logging, pkg_version
from realhf.base.datapack import flat2d
from realhf.base.monitor import CUDATimeMarkType, cuda_tmarked
from realhf.impl.model.backend.inference import PipelinableInferenceEngine
@ -34,6 +34,7 @@ from realhf.impl.model.nn.real_llm_api import ReaLModel
from realhf.impl.model.nn.real_llm_base import ReaLModelBlock
from realhf.impl.model.parallelism.pipeline_parallel.tensor_storage import TensorBuffer
megatron_available = pkg_version.is_available("megatron.core")
try:
# Monkey patch
import megatron.core.optimizer as mcore_optim
@ -42,13 +43,15 @@ try:
def get_model_parallel_group(self):
return constants.parallelism_group()
def get_grad_stats_parallel_group(self):
return constants.parallelism_group()
mcore_optim.DistributedOptimizer = DistributedOptimizer
from megatron.core import parallel_state
from megatron.core.distributed.distributed_data_parallel import (
DistributedDataParallel,
)
from megatron.core.distributed.param_and_grad_buffer import ParamAndGradBuffer
from megatron.core.optimizer import DistributedOptimizer, get_megatron_optimizer
from megatron.core.optimizer.optimizer_config import (
OptimizerConfig as MegatronOptimizerConfig,
@ -57,7 +60,6 @@ try:
TransformerConfig as MegatronTransformerConfig,
)
megatron_available = True
except (ModuleNotFoundError, ImportError):
# Importing megatron.core in a CPU container will fail because it requires apex.
# Class types must still be defined here for type hinting.
@ -72,18 +74,13 @@ except (ModuleNotFoundError, ImportError):
if megatron_available:
try:
if pkg_version.is_version_greater_or_equal("megatron.core", "0.7.0"):
from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler
use_old_megatron = False
except (ModuleNotFoundError, ImportError):
# The above object is available in 0.9.0 but missing in 0.6.0
else:
from realhf.impl.model.backend.thirdparty.megatron.v0_6_0.lr_schduler import (
OptimizerParamScheduler,
)
use_old_megatron = True
WITHIN_MEGATRON_CONTEXT = False
@ -112,6 +109,13 @@ def megatron_ctx():
grid.get_data_parallel_group_gloo()
)
parallel_state._DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = dist.get_process_group_ranks(g)
if pkg_version.is_version_greater_or_equal("megatron.core", "0.11.0"):
parallel_state._INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP = (
constants.data_parallel_group()
)
parallel_state._INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP_GLOO = (
grid.get_data_parallel_group_gloo()
)
# Build the context-parallel groups.
parallel_state._CONTEXT_PARALLEL_GROUP = constants.self_group()
@ -119,10 +123,17 @@ def megatron_ctx():
# Build the model-parallel groups.
parallel_state._MODEL_PARALLEL_GROUP = grid.get_model_parallel_group()
if pkg_version.is_version_greater_or_equal("megatron.core", "0.11.0"):
g = grid.get_model_parallel_group()
parallel_state._MODEL_PARALLEL_GLOBAL_RANKS = dist.get_process_group_ranks(g)
# Build the tensor model-parallel groups.
g = constants.model_parallel_group()
parallel_state._TENSOR_MODEL_PARALLEL_GROUP = g
if pkg_version.is_version_greater_or_equal("megatron.core", "0.11.0"):
g = constants.model_parallel_group()
parallel_state._TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = (
dist.get_process_group_ranks(g)
)
# Build the pipeline model-parallel groups and embedding groups
# (first and last rank in each pipeline model-parallel group).
@ -145,16 +156,38 @@ def megatron_ctx():
# Build the tensor + data parallel groups.
parallel_state._TENSOR_AND_DATA_PARALLEL_GROUP = grid.tp_dp_proc_group
parallel_state._TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = grid.tp_dp_proc_group
if pkg_version.is_version_greater_or_equal("megatron.core", "0.11.0"):
# Build the tensor + context parallel groups
parallel_state._TENSOR_AND_CONTEXT_PARALLEL_GROUP = (
constants.model_parallel_group()
)
# Build the tensor + expert parallel groups
# Build expert parallel groups.
parallel_state._EXPERT_MODEL_PARALLEL_GROUP = constants.self_group()
parallel_state._TENSOR_AND_EXPERT_PARALLEL_GROUP = constants.model_parallel_group()
g = constants.data_parallel_group()
parallel_state._DATA_MODULO_EXPERT_PARALLEL_GROUP = g
parallel_state._DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP = g
parallel_state._DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO = (
grid.get_data_parallel_group_gloo()
)
if pkg_version.is_version_greater_or_equal("megatron.core", "0.11.0"):
parallel_state._EXPERT_TENSOR_PARALLEL_GROUP = constants.self_group()
parallel_state._EXPERT_TENSOR_AND_MODEL_PARALLEL_GROUP = constants.self_group()
parallel_state._EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP = (
grid.get_pipe_parallel_group()
)
parallel_state._EXPERT_DATA_PARALLEL_GROUP = constants.data_parallel_group()
parallel_state._EXPERT_DATA_PARALLEL_GROUP_GLOO = (
grid.get_data_parallel_group_gloo()
)
else:
parallel_state._TENSOR_AND_EXPERT_PARALLEL_GROUP = (
constants.model_parallel_group()
)
parallel_state._DATA_MODULO_EXPERT_PARALLEL_GROUP = (
constants.data_parallel_group()
)
parallel_state._DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP = (
constants.data_parallel_group()
)
parallel_state._DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO = (
grid.get_data_parallel_group_gloo()
)
# Remove the global memory buffer for megatron to save GPU memory.
parallel_state._GLOBAL_MEMORY_BUFFER = None
@ -514,12 +547,17 @@ class MegatronTrainBackend(model_api.ModelBackend, MegatronConfig):
self, model: model_api.Model, spec: model_api.FinetuneSpec
) -> model_api.Model:
module = model.module
if not isinstance(module, ReaLModel):
raise ValueError("MegatronTrainBackend only supports ReaLModel.")
if isinstance(self.ddp, dict):
if pkg_version.is_version_greater_or_equal("megatron.core", "0.11.0"):
from megatron.core.distributed.distributed_data_parallel_config import (
DistributedDataParallelConfig,
)
self.ddp = DistributedDataParallelConfig(**self.ddp)
with megatron_ctx():
if use_old_megatron:
if pkg_version.is_version_less("megatron.core", "0.7.0"):
module = DistributedDataParallel(
config=get_megatron_transformer_config(module.config),
module=module,
@ -544,11 +582,9 @@ class MegatronTrainBackend(model_api.ModelBackend, MegatronConfig):
if self.ddp.use_distributed_optimizer:
# Remap parameters.
assert len(module.buffers) == 1
param_grad_buf: ParamAndGradBuffer = module.buffers[0]
param_grad_buf = module.buffers[0]
# Map Megatron flattened parameters to ReaLModel!
real_model.contiguous_param = param_grad_buf.param_data
# Sanity checks.
assert real_model._param_size == param_grad_buf.numel, (
real_model._param_size,
@ -590,11 +626,18 @@ class MegatronTrainBackend(model_api.ModelBackend, MegatronConfig):
adam_beta2=betas[1],
adam_eps=self.optimizer.eps,
use_distributed_optimizer=self.ddp.use_distributed_optimizer,
overlap_grad_reduce=self.ddp.overlap_grad_reduce,
overlap_param_gather=self.ddp.overlap_param_gather,
clip_grad=self.optimizer.gradient_clipping,
log_num_zeros_in_grad=False,
)
if pkg_version.is_version_greater_or_equal("megatron.core", "0.11.0"):
opt_cfg.overlap_param_gather_with_optimizer_step = (
self.overlap_param_gather_with_optimizer_step
)
opt_cfg.use_precision_aware_optimizer = self.use_precision_aware_optimizer
opt_cfg.main_grads_dtype = getattr(torch, self.main_grads_dtype)
opt_cfg.main_params_dtype = getattr(torch, self.main_params_dtype)
opt_cfg.exp_avg_dtype = getattr(torch, self.exp_avg_dtype)
opt_cfg.exp_avg_sq_dtype = getattr(torch, self.exp_avg_sq_dtype)
with megatron_ctx():
# no_weight_decay_cond and scale_lr_cond have the following signature:
@ -632,13 +675,16 @@ class MegatronTrainBackend(model_api.ModelBackend, MegatronConfig):
def destroy(self, model: model_api.Model):
assert isinstance(model.module, ReaLMegatronEngine)
optimizer = model.module.engine.optim
# The Megatron backend will register forward hooks that
# create circular references (grad -> param -> grad).
# Deleting models directly will not release the memory.
# We must disable the hooks first.
if self.ddp.use_distributed_optimizer and self.ddp.overlap_param_gather:
optimizer.disable_pre_hook()
if pkg_version.is_version_greater_or_equal("megatron.core", "0.11.0"):
model.module.module.engine.ddp.disable_forward_pre_hook()
else:
optimizer = model.module.engine.optim
if self.ddp.use_distributed_optimizer and self.ddp.overlap_param_gather:
optimizer.disable_pre_hook()
def save(self, model: model_api.Model, save_dir: str):
assert isinstance(model.module, ReaLMegatronEngine)

View File

@ -7,7 +7,6 @@ import os
import sys
import time
import traceback
from importlib.metadata import version
from typing import Dict, List, Tuple
import aiohttp
@ -16,7 +15,6 @@ import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import transformers
from packaging.version import Version
from tqdm.asyncio import tqdm
from realhf.api.core import data_api
@ -32,7 +30,15 @@ from realhf.api.core.model_api import (
register_backend,
)
from realhf.api.quickstart.model import SGLangConfig
from realhf.base import cluster, constants, gpu_utils, logging, network, seeding
from realhf.base import (
cluster,
constants,
gpu_utils,
logging,
network,
pkg_version,
seeding,
)
logger = logging.getLogger("SGLang backend")
@ -41,6 +47,12 @@ def remove_prefix(text: str, prefix: str) -> str:
return text[len(prefix) :] if text.startswith(prefix) else text
if pkg_version.is_version_greater_or_equal("sglang", "0.4.4"):
SGLANG_TOKEN_OUTPUT_IDENTIFIER = "output_ids"
else:
SGLANG_TOKEN_OUTPUT_IDENTIFIER = "token_ids"
class SGLangAPIClient(LLMAPIClient):
async def _do_generate(
@ -95,7 +107,7 @@ class SGLangAPIClient(LLMAPIClient):
# NOTE: Some completion APIs send a final usage-summary
# response without a token, so we check that a token
# was actually generated
if data["token_ids"]:
if data[SGLANG_TOKEN_OUTPUT_IDENTIFIER]:
timestamp = time.perf_counter()
# First token
if ttft == 0.0:
@ -107,7 +119,7 @@ class SGLangAPIClient(LLMAPIClient):
output.itl.append(timestamp - most_recent_timestamp)
most_recent_timestamp = timestamp
output_ids = data["token_ids"]
output_ids = data[SGLANG_TOKEN_OUTPUT_IDENTIFIER]
finish_reason = data["meta_info"]["finish_reason"]
output_logprobs = data["meta_info"][
"output_token_logprobs"
@ -148,8 +160,7 @@ def sglang_server_process(server_args_dict):
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import kill_process_tree
sglang_version = version("sglang")
if Version(sglang_version) < Version("0.4.3"):
if pkg_version.is_version_less("sglang", "0.4.3"):
from sglang.srt.server import launch_server
server_args_dict.pop("enable_nccl_nvls")
@ -423,7 +434,7 @@ class SGLangGenerationBackend(ModelBackend, SGLangConfig):
tp_size=constants.model_parallel_world_size(),
# Because we have set CUDA_VISIBLE_DEVICES to a single GPU in each process
base_gpu_id=int(os.environ["CUDA_VISIBLE_DEVICES"]),
file_storage_pth=os.path.join(
file_storage_path=os.path.join(
constants.SGLANG_CACHE_PATH,
f"sglang_storage{constants.data_parallel_rank()}",
),

View File

@ -8,9 +8,7 @@ from typing import Dict, Optional, Tuple
import torch
import torch.distributed
import realhf.base.constants as constants
from realhf.impl.model.parallelism.model_parallel.utils import VocabUtility
from realhf.impl.model.utils.functional import build_leave_one_indices
from realhf.base import pkg_version
class KLController:
@ -270,31 +268,6 @@ def get_packed_reward_dense(
return kl_rewards, tot_rewards
def pygae2d_olp(
rewards: torch.FloatTensor,
values: torch.FloatTensor,
dones: torch.BoolTensor,
truncates: torch.BoolTensor,
gamma: float,
lam: float,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
episode_length = int(rewards.shape[1])
masks = 1 - dones.float()
truncate_mask = 1 - truncates.float()
delta = rewards + gamma * values[:, 1:] * masks[:, 1:] - values[:, :-1]
adv = torch.zeros_like(rewards)
gae = torch.zeros_like(rewards[:, 0])
m = gamma * lam * masks[:, 1:]
step = episode_length - 1
while step >= 0:
# If the env is forcibly terminated (truncated), abandon the final step,
# i.e., the advantage of the final step is 0 and its value target is the predicted value
gae = (delta[:, step] + m[:, step] * gae) * truncate_mask[:, step + 1]
adv[:, step] = gae
step -= 1
return adv, adv + values[:, :-1]
def pygae1d_nolp_misalign(
rewards: torch.FloatTensor,
values: torch.FloatTensor,
@ -329,37 +302,6 @@ def pygae1d_nolp_misalign(
return advantages, returns
def pygae2d_nolp(
rewards: torch.FloatTensor,
values: torch.FloatTensor,
on_reset: torch.BoolTensor,
truncates: torch.BoolTensor,
gamma: float,
lam: float,
) -> torch.FloatTensor:
on_reset = on_reset.float()
truncates = truncates.float()
episode_length = int(rewards.shape[1])
delta = rewards + gamma * values[:, 1:] * (1 - on_reset[:, 1:]) - values[:, :-1]
gae = torch.zeros_like(rewards[:, 0])
adv = torch.zeros_like(rewards)
# 1. If the next step is a new episode, GAE doesn't propagate back
# 2. If the next step is a truncated final step, the backpropagated GAE is -V(t),
# which is not correct. We ignore it such that the current GAE is r(t-1)+ɣV(t)-V(t-1)
# 3. If the next step is a done final step, the backpropagated GAE is zero.
m = gamma * lam * (1 - on_reset[:, 1:]) * (1 - truncates[:, 1:])
step = episode_length - 1
while step >= 0:
gae = delta[:, step] + m[:, step] * gae
adv[:, step] = gae
step -= 1
return adv, adv + values[:, :-1]
def cugae1d_nolp_misalign_func(
rewards: torch.FloatTensor,
values: torch.FloatTensor,
@ -393,134 +335,14 @@ def cugae1d_nolp_misalign_func(
Tuple[torch.FloatTensor, torch.FloatTensor]: Advantages and returns (value targets).
Both have the same shape as rewards.
"""
import realhf._C.cugae as gae_cuda
if pkg_version.is_available("cugae"):
from cugae import cugae1d_nolp_misalign_func as gae_1d_nolp_misalign
else:
from realhf._C.cugae import gae_1d_nolp_misalign
assert len(rewards.shape) == len(values.shape) == len(cu_seqlens.shape) == 1
assert cu_seqlens[0] == 0 and cu_seqlens[-1] == rewards.shape[0]
return gae_cuda.gae_1d_nolp_misalign(
rewards, values, cu_seqlens, truncate, gamma, lam
)
def cugae2d_olp_func(
rewards: torch.FloatTensor,
values: torch.FloatTensor,
dones: torch.BoolTensor,
truncates: torch.BoolTensor,
gamma: float,
lam: float,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
"""Compute GAE over batched sequences with variable lengths, assuming
overlapped sequences.
This function assumes that rewards and values are batched as 2D tensors.
The first dimension is batch_size and the second dimension is the number of collected timesteps.
Each batch slot may contain multiple sequences, and sequences may have different lengths.
The length of each sequence is marked by dones.
`dones` marks the termination of each sequence, whether or not it is truncated.
`truncates` marks truncation, and its nonzero indices must be a subset of `dones`.
If truncated, abandon the GAE computation of the last step (because we don't have the bootstrapped
value in this case) and start from the second-to-last step.
The final step of each sequence *is overlapped* by the first step of the next sequence,
i.e., auto-reset, which is widely used in libraries such as gym. In other words, the
steps where `dones` is True are actually the first steps of sequences. Therefore,
this function is suffixed with "olp" (overlap).
Args:
rewards (torch.FloatTensor): Shape [batch_size, seqlen].
values (torch.FloatTensor): Shape [batch_size, seqlen + 1], with one more bootstrap step.
dones (torch.BoolTensor): Shape [batch_size, seqlen + 1].
truncates (torch.BoolTensor): Shape [batch_size, seqlen + 1].
gamma (float): Discount factor.
lam (float): GAE discount factor.
Returns:
Tuple[torch.FloatTensor, torch.FloatTensor]: Advantages and returns (value targets).
Both have the same shape as rewards.
"""
import realhf._C.cugae as gae_cuda
truncates_indices = truncates.nonzero()
assert torch.all(dones[truncates_indices[:, 0], truncates_indices[:, 1]])
done_indices = dones.nonzero()
num_dones = dones.float().sum(1)
max_num_dones = int(num_dones.max())
cu_num_dones = torch.nn.functional.pad(num_dones.cumsum(0), (1, 0), value=0).int()
is_truncate = truncates[done_indices[:, 0], done_indices[:, 1]]
return gae_cuda.gae_2d_olp(
rewards,
values,
dones,
done_indices[:, 1].int(),
cu_num_dones,
max_num_dones,
is_truncate,
gamma,
lam,
)
def cugae2d_nolp_func(
rewards: torch.FloatTensor,
values: torch.FloatTensor,
on_reset: torch.BoolTensor,
truncates: torch.BoolTensor,
gamma: float,
lam: float,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
"""Compute GAE over batched sequences with variable lengths, assuming non-
overlapped sequences.
This function assumes that rewards and values are batched as 2D tensors.
The first dimension is batch_size and the second dimension is the number of collected timesteps.
Each batch slot may contain multiple sequences, and sequences may have different lengths.
The length of each sequence is marked by `on_reset`.
`on_reset` marks the beginning of each sequence. `truncates` marks truncation.
If truncated, values will be bootstrapped from the `done` step.
The final step of each sequence is *NOT* overlapped by the first step of the next sequence.
Each sequence will be complete. The last step should only have observations but no rewards.
This is used in SRL. Therefore, this function is suffixed with "nolp" (non-overlap).
Args:
rewards (torch.FloatTensor): Shape [batch_size, seqlen].
values (torch.FloatTensor): Shape [batch_size, seqlen + 1], with one more bootstrap step.
dones (torch.BoolTensor): Shape [batch_size, seqlen + 1].
truncates (torch.BoolTensor): Shape [batch_size, seqlen + 1].
gamma (float): Discount factor.
lam (float): GAE discount factor.
Returns:
Tuple[torch.FloatTensor, torch.FloatTensor]: Advantages and returns (value targets).
Both have the same shape as rewards.
"""
import realhf._C.cugae as gae_cuda
dones = on_reset[:, 1:]
truncates_indices = truncates[:, :-1].nonzero()
assert torch.all(dones[truncates_indices[:, 0], truncates_indices[:, 1]])
on_reset_indices = on_reset.nonzero()
num_resets = on_reset.float().sum(1)
max_num_resets = int(num_resets.max())
cu_num_resets = torch.nn.functional.pad(num_resets.cumsum(0), (1, 0), value=0).int()
truncates = torch.cat(
[torch.zeros_like(truncates[:, 0:1]), truncates[:, :-1]], dim=1
)
bootstrap = truncates[on_reset_indices[:, 0], on_reset_indices[:, 1]]
return gae_cuda.gae_2d_nolp(
rewards,
values,
on_reset,
on_reset_indices[:, 1].int(),
cu_num_resets,
max_num_resets,
bootstrap,
gamma,
lam,
)
return gae_1d_nolp_misalign(rewards, values, cu_seqlens, truncate, gamma, lam)
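A hedged usage sketch of the 1D entry point this PR keeps (requires a GPU and the compiled cugae / realhf._C extension; the shape of `values` and the per-sequence `truncate` flags are assumptions inferred from the asserts and the removed 2D variants above):

import torch

# Two packed sequences of lengths 3 and 2 (5 steps total); assuming `values`
# carries one extra bootstrap step per sequence, hence 5 + 2 = 7 entries.
rewards = torch.randn(5, device="cuda")
values = torch.randn(7, device="cuda")
cu_seqlens = torch.tensor([0, 3, 5], dtype=torch.int32, device="cuda")
truncate = torch.zeros(2, dtype=torch.bool, device="cuda")
adv, ret = cugae1d_nolp_misalign_func(rewards, values, cu_seqlens, truncate, 0.99, 0.95)
assert adv.shape == rewards.shape and ret.shape == rewards.shape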
@torch.no_grad()

View File

@ -9,11 +9,7 @@ import torch
from realhf.impl.model.utils.ppo_functional import (
cugae1d_nolp_misalign_func,
cugae2d_nolp_func,
cugae2d_olp_func,
pygae1d_nolp_misalign,
pygae2d_nolp,
pygae2d_olp,
)
@ -58,78 +54,3 @@ def test_gae1d_nolp_misalign(max_seqlen: int, bs: int, gamma: float, lam: float)
f"max_seqlen={max_seqlen},bs={bs}, CUDA acceleration ratio",
(t2 - t1) / (t3 - t2),
)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="This test requires a GPU.")
@pytest.mark.parametrize("seqlen", [32, 128, 512, 1024])
@pytest.mark.parametrize("bs", [8, 16, 32, 100])
@pytest.mark.parametrize("gamma", [0.9, 1.0])
@pytest.mark.parametrize("lam", [0.5, 1.0])
def test_gae2d_olp(bs: int, seqlen: int, gamma: float, lam: float):
torch.random.manual_seed(0)
rewards = torch.randn(bs, seqlen).cuda()
values = torch.randn(bs, seqlen + 1).cuda()
dones = torch.randint(0, 2, (bs, seqlen + 1)).bool().cuda()
truncates = dones.logical_and(torch.randint(0, 2, (bs, seqlen + 1)).bool().cuda())
py_adv, py_ret = pygae2d_olp(rewards, values, dones, truncates, gamma, lam)
adv, ret = cugae2d_olp_func(rewards, values, dones, truncates, gamma, lam)
torch.cuda.synchronize()
t1 = time.perf_counter_ns()
py_adv, py_ret = pygae2d_olp(rewards, values, dones, truncates, gamma, lam)
torch.cuda.synchronize()
t2 = time.perf_counter_ns()
adv, ret = cugae2d_olp_func(rewards, values, dones, truncates, gamma, lam)
torch.cuda.synchronize()
t3 = time.perf_counter_ns()
assert torch.allclose(adv, py_adv, atol=1e-5), (adv - py_adv).abs().max()
assert torch.allclose(ret, py_ret, atol=1e-5), (ret - py_ret).abs().max()
print(
f"seqlen={seqlen},bs={bs}, CUDA acceleration ratio",
(t2 - t1) / (t3 - t2),
)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="This test requires a GPU.")
@pytest.mark.parametrize("seqlen", [32, 128, 512, 1024])
@pytest.mark.parametrize("bs", [8, 16, 32, 100])
@pytest.mark.parametrize("gamma", [0.9, 1.0])
@pytest.mark.parametrize("lam", [0.5, 1.0])
def test_gae2d_nolp(bs: int, seqlen: int, gamma: float, lam: float):
torch.random.manual_seed(0)
rewards = torch.randn(bs, seqlen).cuda()
values = torch.randn(bs, seqlen + 1).cuda()
on_reset_ = torch.randint(0, 2, (bs, seqlen + 2)).bool().cuda()
on_reset = on_reset_[:, :-1].contiguous()
truncates = (
on_reset_[:, 1:]
.logical_and(torch.randint(0, 2, (bs, seqlen + 1)).bool().cuda())
.contiguous()
)
py_adv, py_ret = pygae2d_nolp(rewards, values, on_reset, truncates, gamma, lam)
adv, ret = cugae2d_nolp_func(rewards, values, on_reset, truncates, gamma, lam)
torch.cuda.synchronize()
t1 = time.perf_counter_ns()
py_adv, py_ret = pygae2d_nolp(rewards, values, on_reset, truncates, gamma, lam)
torch.cuda.synchronize()
t2 = time.perf_counter_ns()
adv, ret = cugae2d_nolp_func(rewards, values, on_reset, truncates, gamma, lam)
torch.cuda.synchronize()
t3 = time.perf_counter_ns()
adv = adv * (1 - truncates[:, :-1].float())
py_adv = py_adv * (1 - truncates[:, :-1].float())
ret = ret * (1 - truncates[:, :-1].float())
py_ret = py_ret * (1 - truncates[:, :-1].float())
assert torch.allclose(adv, py_adv, atol=1e-5), (adv - py_adv).abs().max()
assert torch.allclose(ret, py_ret, atol=1e-5), (ret - py_ret).abs().max()
print(
f"seqlen={seqlen},bs={bs}, CUDA acceleration ratio",
(t2 - t1) / (t3 - t2),
)

View File

@ -112,7 +112,10 @@ def test_fn(
from realhf.impl.model.backend.sglang import SGLangGenerationBackend
backend = SGLangGenerationBackend(model_path=path)
backend = SGLangGenerationBackend(
model_path=path,
dtype="bfloat16" if module.dtype == torch.bfloat16 else torch.float16,
)
model = model_api.Model(
name=model_name,
module=module,
@ -188,7 +191,6 @@ def test_fn(
print("success")
# Clean up the distributed environment
dist.destroy_process_group()