mirror of https://github.com/inclusionAI/AReaL
PullRequest: 56 Support the CUDA 12.8 image with Megatron v0.11.0 and SGLang 0.4.4
Merge branch fw/megatron-v0.11.0 of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/56
Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com>

* update trial
* add moe test script
* .
* .
* .
* .
* .
* .
* .
* .
* .
* remove gae2d
* .
Parent: 46b7d3d32b
Commit: 6ccbb01ca8
@@ -147,7 +147,6 @@ class vLLMConfig:

class SGLangConfig:
    disable_cuda_graph: bool = False
    disable_radix_cache: bool = False
    disable_jump_forward: bool = False
    disable_cuda_graph_padding: bool = False
    enable_nccl_nvls: bool = False
    disable_outlines_disk_cache: bool = False

@@ -169,7 +168,6 @@ class SGLangConfig:
    num_continuous_decode_steps: int = 1
    enable_memory_saver: bool = False
    allow_auto_truncate: bool = False
    return_hidden_states: bool = False
    # NOTE: to avoid the illegal memory access error
    attention_backend: Optional[str] = "triton"
    sampling_backend: Optional[str] = None

@@ -253,6 +251,13 @@ class MegatronConfig:
    )
    # Don't use MegatronOptimizerConfig here because OmegaConf
    # does not recognize the annotation "torch.dtype"
    overlap_param_gather_with_optimizer_step: bool = False

    use_precision_aware_optimizer: bool = False
    main_grads_dtype: str = "float32"
    main_params_dtype: str = "float32"
    exp_avg_dtype: str = "float32"
    exp_avg_sq_dtype: str = "float32"


@dataclasses.dataclass
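
The dtype knobs above are kept as plain strings because, as the comment in the hunk notes, OmegaConf cannot handle a torch.dtype annotation; the Megatron backend later converts them with getattr(torch, ...). A minimal sketch of that round trip, where _PrecisionAwareOptim is a hypothetical stand-in for the relevant slice of MegatronConfig and the values are illustrative:

    import dataclasses

    import torch
    from omegaconf import OmegaConf

    @dataclasses.dataclass
    class _PrecisionAwareOptim:
        # Mirrors the new MegatronConfig fields: strings instead of torch.dtype,
        # so OmegaConf can validate and merge them from YAML/CLI overrides.
        use_precision_aware_optimizer: bool = True
        main_grads_dtype: str = "float32"
        exp_avg_dtype: str = "bfloat16"

    cfg = OmegaConf.structured(_PrecisionAwareOptim)
    main_grads_dtype = getattr(torch, cfg.main_grads_dtype)  # torch.float32
    exp_avg_dtype = getattr(torch, cfg.exp_avg_dtype)  # torch.bfloat16
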
@@ -0,0 +1,49 @@
from importlib.metadata import PackageNotFoundError
from importlib.metadata import version as get_version

from packaging.version import Version


def is_available(pkg_name):
    # get_version() raises PackageNotFoundError for uninstalled packages,
    # so catch it instead of letting the availability probe crash.
    try:
        return bool(get_version(pkg_name))
    except PackageNotFoundError:
        return False


def compare_versions(version1: str, version2: str) -> int:
    """
    Compare two version strings.

    :param version1: First version string.
    :param version2: Second version string.
    :return: -1 if version1 < version2, 0 if version1 == version2, 1 if version1 > version2.
    """
    v1 = Version(version1)
    v2 = Version(version2)
    if v1 < v2:
        return -1
    elif v1 == v2:
        return 0
    else:
        return 1


def is_version_greater_or_equal(package_name: str, target_version: str) -> bool:
    """
    Check if the installed version of a package is greater than or equal to the target version.

    :param package_name: Name of the package.
    :param target_version: Target version to compare against.
    :return: True if the installed version is greater than or equal to the target version, False otherwise.
    """
    installed_version = get_version(package_name)
    return compare_versions(installed_version, target_version) >= 0


def is_version_less(package_name: str, target_version: str) -> bool:
    """
    Check if the installed version of a package is less than the target version.

    :param package_name: Name of the package.
    :param target_version: Target version to compare against.
    :return: True if the installed version is less than the target version, False otherwise.
    """
    installed_version = get_version(package_name)
    return compare_versions(installed_version, target_version) < 0
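
A short usage sketch of these helpers (the module is imported elsewhere in this PR as realhf.base.pkg_version), mirroring how the rest of the PR gates version-specific Megatron and SGLang code paths; the branch bodies are placeholders:

    from realhf.base import pkg_version

    if pkg_version.is_available("megatron.core"):
        if pkg_version.is_version_greater_or_equal("megatron.core", "0.11.0"):
            ...  # take the megatron.core v0.11+ path (new parallel-state groups, DDP config)
        elif pkg_version.is_version_less("megatron.core", "0.7.0"):
            ...  # fall back to the legacy scheduler under thirdparty/megatron/v0_6_0
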
@@ -21,10 +21,16 @@ logger = logging.getLogger("model init")

# Import all model implementations.
_p = re.compile(r"^(?!.*__init__).*\.py$")
_filepath = os.path.dirname(__file__)
import_module(os.path.join(_filepath, "backend"), _p)
import_module(os.path.join(_filepath, "interface"), _p)
import_module(os.path.join(_filepath, "nn"), _p)

# NOTE: skip importing vLLM for now to avoid an
# "invalid device context" issue for the 25.01 image
import realhf.impl.model.backend.inference
import realhf.impl.model.backend.megatron
import realhf.impl.model.backend.mock_train
import realhf.impl.model.backend.sglang

# Set PyTorch JIT options, following Megatron-LM.
if torch.cuda.is_available():
    torch._C._jit_set_profiling_executor(True)
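
The module-scan regex above uses a negative lookahead to skip package initializers; a quick illustration of what it accepts (the file names are made up):

    import re

    _p = re.compile(r"^(?!.*__init__).*\.py$")

    assert _p.match("sglang.py")
    assert _p.match("backend/megatron.py")
    assert _p.match("__init__.py") is None
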
@@ -23,7 +23,7 @@ from realhf.api.quickstart.model import (
    MegatronConfig,
    OptimizerConfig,
)
from realhf.base import constants, logging
from realhf.base import constants, logging, pkg_version
from realhf.base.datapack import flat2d
from realhf.base.monitor import CUDATimeMarkType, cuda_tmarked
from realhf.impl.model.backend.inference import PipelinableInferenceEngine

@@ -34,6 +34,7 @@ from realhf.impl.model.nn.real_llm_api import ReaLModel
from realhf.impl.model.nn.real_llm_base import ReaLModelBlock
from realhf.impl.model.parallelism.pipeline_parallel.tensor_storage import TensorBuffer

megatron_available = pkg_version.is_available("megatron.core")
try:
    # Monkey patch
    import megatron.core.optimizer as mcore_optim

@@ -42,13 +43,15 @@ try:
        def get_model_parallel_group(self):
            return constants.parallelism_group()

        def get_grad_stats_parallel_group(self):
            return constants.parallelism_group()

    mcore_optim.DistributedOptimizer = DistributedOptimizer

    from megatron.core import parallel_state
    from megatron.core.distributed.distributed_data_parallel import (
        DistributedDataParallel,
    )
    from megatron.core.distributed.param_and_grad_buffer import ParamAndGradBuffer
    from megatron.core.optimizer import DistributedOptimizer, get_megatron_optimizer
    from megatron.core.optimizer.optimizer_config import (
        OptimizerConfig as MegatronOptimizerConfig,

@@ -57,7 +60,6 @@ try:
        TransformerConfig as MegatronTransformerConfig,
    )

    megatron_available = True
except (ModuleNotFoundError, ImportError):
    # Importing megatron.core in a CPU container will fail due to the requirement of apex.
    # Here class types must be defined for type hinting.

@@ -72,18 +74,13 @@ except (ModuleNotFoundError, ImportError):

if megatron_available:
    try:
    if pkg_version.is_version_greater_or_equal("megatron.core", "0.7.0"):
        from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler

        use_old_megatron = False
    except (ModuleNotFoundError, ImportError):
        # The above object is available in 0.9.0 but missing in 0.6.0
    else:
        from realhf.impl.model.backend.thirdparty.megatron.v0_6_0.lr_schduler import (
            OptimizerParamScheduler,
        )

        use_old_megatron = True


WITHIN_MEGATRON_CONTEXT = False
@@ -112,6 +109,13 @@ def megatron_ctx():
        grid.get_data_parallel_group_gloo()
    )
    parallel_state._DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = dist.get_process_group_ranks(g)
    if pkg_version.is_version_greater_or_equal("megatron.core", "0.11.0"):
        parallel_state._INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP = (
            constants.data_parallel_group()
        )
        parallel_state._INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP_GLOO = (
            grid.get_data_parallel_group_gloo()
        )

    # Build the context-parallel groups.
    parallel_state._CONTEXT_PARALLEL_GROUP = constants.self_group()

@@ -119,10 +123,17 @@ def megatron_ctx():

    # Build the model-parallel groups.
    parallel_state._MODEL_PARALLEL_GROUP = grid.get_model_parallel_group()
    if pkg_version.is_version_greater_or_equal("megatron.core", "0.11.0"):
        g = grid.get_model_parallel_group()
        parallel_state._MODEL_PARALLEL_GLOBAL_RANKS = dist.get_process_group_ranks(g)

    # Build the tensor model-parallel groups.
    g = constants.model_parallel_group()
    parallel_state._TENSOR_MODEL_PARALLEL_GROUP = g
    if pkg_version.is_version_greater_or_equal("megatron.core", "0.11.0"):
        g = constants.model_parallel_group()
        parallel_state._TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = (
            dist.get_process_group_ranks(g)
        )

    # Build the pipeline model-parallel groups and embedding groups
    # (first and last rank in each pipeline model-parallel group).

@@ -145,16 +156,38 @@ def megatron_ctx():
    # Build the tensor + data parallel groups.
    parallel_state._TENSOR_AND_DATA_PARALLEL_GROUP = grid.tp_dp_proc_group
    parallel_state._TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = grid.tp_dp_proc_group
    if pkg_version.is_version_greater_or_equal("megatron.core", "0.11.0"):
        # Build the tensor + context parallel groups
        parallel_state._TENSOR_AND_CONTEXT_PARALLEL_GROUP = (
            constants.model_parallel_group()
        )

    # Build the tensor + expert parallel groups
    # Build expert parallel groups.
    parallel_state._EXPERT_MODEL_PARALLEL_GROUP = constants.self_group()
    parallel_state._TENSOR_AND_EXPERT_PARALLEL_GROUP = constants.model_parallel_group()
    g = constants.data_parallel_group()
    parallel_state._DATA_MODULO_EXPERT_PARALLEL_GROUP = g
    parallel_state._DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP = g
    parallel_state._DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO = (
        grid.get_data_parallel_group_gloo()
    )

    if pkg_version.is_version_greater_or_equal("megatron.core", "0.11.0"):
        parallel_state._EXPERT_TENSOR_PARALLEL_GROUP = constants.self_group()
        parallel_state._EXPERT_TENSOR_AND_MODEL_PARALLEL_GROUP = constants.self_group()
        parallel_state._EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP = (
            grid.get_pipe_parallel_group()
        )
        parallel_state._EXPERT_DATA_PARALLEL_GROUP = constants.data_parallel_group()
        parallel_state._EXPERT_DATA_PARALLEL_GROUP_GLOO = (
            grid.get_data_parallel_group_gloo()
        )
    else:
        parallel_state._TENSOR_AND_EXPERT_PARALLEL_GROUP = (
            constants.model_parallel_group()
        )
        parallel_state._DATA_MODULO_EXPERT_PARALLEL_GROUP = (
            constants.data_parallel_group()
        )
        parallel_state._DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP = (
            constants.data_parallel_group()
        )
        parallel_state._DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO = (
            grid.get_data_parallel_group_gloo()
        )

    # Remove the global memory buffer for megatron to save GPU memory.
    parallel_state._GLOBAL_MEMORY_BUFFER = None
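
megatron_ctx() above writes directly into private megatron.core.parallel_state attributes, several of which only exist from v0.11.0 onward, hence the repeated version checks. A defensive variant, shown only as a sketch and not what AReaL does, could gate each assignment on attribute existence instead (assuming parallel_state and constants are already imported as above):

    def _set_if_present(name: str, group) -> None:
        # Only touch parallel_state slots that the installed megatron.core defines,
        # which is what the version-gated branches above effectively guarantee.
        if hasattr(parallel_state, name):
            setattr(parallel_state, name, group)

    _set_if_present("_TENSOR_AND_CONTEXT_PARALLEL_GROUP", constants.model_parallel_group())
    _set_if_present("_EXPERT_TENSOR_PARALLEL_GROUP", constants.self_group())
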
@@ -514,12 +547,17 @@ class MegatronTrainBackend(model_api.ModelBackend, MegatronConfig):
        self, model: model_api.Model, spec: model_api.FinetuneSpec
    ) -> model_api.Model:
        module = model.module

        if not isinstance(module, ReaLModel):
            raise ValueError("MegatronTrainBackend only supports ReaLModel.")
        if isinstance(self.ddp, dict):
            if pkg_version.is_version_greater_or_equal("megatron.core", "0.11.0"):
                from megatron.core.distributed.distributed_data_parallel_config import (
                    DistributedDataParallelConfig,
                )
            self.ddp = DistributedDataParallelConfig(**self.ddp)
        with megatron_ctx():
            if use_old_megatron:
            if pkg_version.is_version_less("megatron.core", "0.7.0"):
                module = DistributedDataParallel(
                    config=get_megatron_transformer_config(module.config),
                    module=module,

@@ -544,11 +582,9 @@ class MegatronTrainBackend(model_api.ModelBackend, MegatronConfig):
        if self.ddp.use_distributed_optimizer:
            # Remap parameters.
            assert len(module.buffers) == 1
            param_grad_buf: ParamAndGradBuffer = module.buffers[0]

            param_grad_buf = module.buffers[0]
            # Map Megatron flattened parameters to ReaLModel!
            real_model.contiguous_param = param_grad_buf.param_data

            # Sanity checks.
            assert real_model._param_size == param_grad_buf.numel, (
                real_model._param_size,

@@ -590,11 +626,18 @@ class MegatronTrainBackend(model_api.ModelBackend, MegatronConfig):
            adam_beta2=betas[1],
            adam_eps=self.optimizer.eps,
            use_distributed_optimizer=self.ddp.use_distributed_optimizer,
            overlap_grad_reduce=self.ddp.overlap_grad_reduce,
            overlap_param_gather=self.ddp.overlap_param_gather,
            clip_grad=self.optimizer.gradient_clipping,
            log_num_zeros_in_grad=False,
        )
        if pkg_version.is_version_greater_or_equal("megatron.core", "0.11.0"):
            opt_cfg.overlap_param_gather_with_optimizer_step = (
                self.overlap_param_gather_with_optimizer_step
            )
            opt_cfg.use_precision_aware_optimizer = self.use_precision_aware_optimizer
            opt_cfg.main_grads_dtype = getattr(torch, self.main_grads_dtype)
            opt_cfg.main_params_dtype = getattr(torch, self.main_params_dtype)
            opt_cfg.exp_avg_dtype = getattr(torch, self.exp_avg_dtype)
            opt_cfg.exp_avg_sq_dtype = getattr(torch, self.exp_avg_sq_dtype)

        with megatron_ctx():
            # no_weight_decay_cond and scale_lr_cond have the following signature:

@@ -632,13 +675,16 @@ class MegatronTrainBackend(model_api.ModelBackend, MegatronConfig):

    def destroy(self, model: model_api.Model):
        assert isinstance(model.module, ReaLMegatronEngine)
        optimizer = model.module.engine.optim
        # The Megatron backend will register forward hooks that
        # create circular references (grad -> param -> grad).
        # Deleting models directly will not release the memory.
        # We must disable the hooks first.
        if self.ddp.use_distributed_optimizer and self.ddp.overlap_param_gather:
            optimizer.disable_pre_hook()
        if pkg_version.is_version_greater_or_equal("megatron.core", "0.11.0"):
            model.module.module.engine.ddp.disable_forward_pre_hook()
        else:
            optimizer = model.module.engine.optim
            if self.ddp.use_distributed_optimizer and self.ddp.overlap_param_gather:
                optimizer.disable_pre_hook()

    def save(self, model: model_api.Model, save_dir: str):
        assert isinstance(model.module, ReaLMegatronEngine)
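
For reference, the dict-to-dataclass conversion shown above in isolation: a plain mapping from the experiment config becomes Megatron's DistributedDataParallelConfig. This requires megatron.core >= 0.11.0; the values are illustrative, though the field names all appear in the hunk above:

    from megatron.core.distributed.distributed_data_parallel_config import (
        DistributedDataParallelConfig,
    )

    ddp_dict = {
        "use_distributed_optimizer": True,
        "overlap_grad_reduce": True,
        "overlap_param_gather": True,
    }
    # Same pattern as `self.ddp = DistributedDataParallelConfig(**self.ddp)` above.
    ddp_cfg = DistributedDataParallelConfig(**ddp_dict)
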
@@ -7,7 +7,6 @@ import os
import sys
import time
import traceback
from importlib.metadata import version
from typing import Dict, List, Tuple

import aiohttp

@@ -16,7 +15,6 @@ import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import transformers
from packaging.version import Version
from tqdm.asyncio import tqdm

from realhf.api.core import data_api

@@ -32,7 +30,15 @@ from realhf.api.core.model_api import (
    register_backend,
)
from realhf.api.quickstart.model import SGLangConfig
from realhf.base import cluster, constants, gpu_utils, logging, network, seeding
from realhf.base import (
    cluster,
    constants,
    gpu_utils,
    logging,
    network,
    pkg_version,
    seeding,
)

logger = logging.getLogger("SGLang backend")


@@ -41,6 +47,12 @@ def remove_prefix(text: str, prefix: str) -> str:
    return text[len(prefix) :] if text.startswith(prefix) else text


if pkg_version.is_version_greater_or_equal("sglang", "0.4.4"):
    SGLANG_TOKEN_OUTPUT_IDENTIFIER = "output_ids"
else:
    SGLANG_TOKEN_OUTPUT_IDENTIFIER = "token_ids"


class SGLangAPIClient(LLMAPIClient):

    async def _do_generate(

@@ -95,7 +107,7 @@ class SGLangAPIClient(LLMAPIClient):
            # NOTE: Some completion APIs might send a final
            # usage-summary response without a token, so we
            # check that a token was actually generated.
            if data["token_ids"]:
            if data[SGLANG_TOKEN_OUTPUT_IDENTIFIER]:
                timestamp = time.perf_counter()
                # First token
                if ttft == 0.0:

@@ -107,7 +119,7 @@ class SGLangAPIClient(LLMAPIClient):
                    output.itl.append(timestamp - most_recent_timestamp)

                most_recent_timestamp = timestamp
                output_ids = data["token_ids"]
                output_ids = data[SGLANG_TOKEN_OUTPUT_IDENTIFIER]
                finish_reason = data["meta_info"]["finish_reason"]
                output_logprobs = data["meta_info"][
                    "output_token_logprobs"

@@ -148,8 +160,7 @@ def sglang_server_process(server_args_dict):
    from sglang.srt.server_args import ServerArgs
    from sglang.srt.utils import kill_process_tree

    sglang_version = version("sglang")
    if Version(sglang_version) < Version("0.4.3"):
    if pkg_version.is_version_less("sglang", "0.4.3"):
        from sglang.srt.server import launch_server

        server_args_dict.pop("enable_nccl_nvls")

@@ -423,7 +434,7 @@ class SGLangGenerationBackend(ModelBackend, SGLangConfig):
            tp_size=constants.model_parallel_world_size(),
            # Because we have set CUDA_VISIBLE_DEVICES to a single GPU in each process
            base_gpu_id=int(os.environ["CUDA_VISIBLE_DEVICES"]),
            file_storage_pth=os.path.join(
            file_storage_path=os.path.join(
                constants.SGLANG_CACHE_PATH,
                f"sglang_storage{constants.data_parallel_rank()}",
            ),
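
A minimal sketch of how the version-dependent key is consumed when decoding one streamed chunk; the chunk layout follows the fields already read above (generated token ids plus meta_info), and the helper name is made up:

    import json

    def parse_stream_chunk(raw: bytes) -> list:
        # sglang >= 0.4.4 reports generated tokens under "output_ids",
        # earlier releases under "token_ids"; SGLANG_TOKEN_OUTPUT_IDENTIFIER
        # (defined above) hides that difference from the parsing code.
        data = json.loads(raw)
        return data.get(SGLANG_TOKEN_OUTPUT_IDENTIFIER, [])
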
@@ -8,9 +8,7 @@ from typing import Dict, Optional, Tuple
import torch
import torch.distributed

import realhf.base.constants as constants
from realhf.impl.model.parallelism.model_parallel.utils import VocabUtility
from realhf.impl.model.utils.functional import build_leave_one_indices
from realhf.base import pkg_version


class KLController:
@@ -270,31 +268,6 @@ def get_packed_reward_dense(
    return kl_rewards, tot_rewards


def pygae2d_olp(
    rewards: torch.FloatTensor,
    values: torch.FloatTensor,
    dones: torch.BoolTensor,
    truncates: torch.BoolTensor,
    gamma: float,
    lam: float,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
    episode_length = int(rewards.shape[1])
    masks = 1 - dones.float()
    truncate_mask = 1 - truncates.float()
    delta = rewards + gamma * values[:, 1:] * masks[:, 1:] - values[:, :-1]
    adv = torch.zeros_like(rewards)
    gae = torch.zeros_like(rewards[:, 0])
    m = gamma * lam * masks[:, 1:]
    step = episode_length - 1
    while step >= 0:
        # If the env is terminated compulsively, then abandon the final step,
        # i.e., the advantage of the final step is 0 and its value target is
        # the predicted value.
        gae = (delta[:, step] + m[:, step] * gae) * truncate_mask[:, step + 1]
        adv[:, step] = gae
        step -= 1
    return adv, adv + values[:, :-1]


def pygae1d_nolp_misalign(
    rewards: torch.FloatTensor,
    values: torch.FloatTensor,
@@ -329,37 +302,6 @@ def pygae1d_nolp_misalign(
    return advantages, returns


def pygae2d_nolp(
    rewards: torch.FloatTensor,
    values: torch.FloatTensor,
    on_reset: torch.BoolTensor,
    truncates: torch.BoolTensor,
    gamma: float,
    lam: float,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
    on_reset = on_reset.float()
    truncates = truncates.float()
    episode_length = int(rewards.shape[1])
    delta = rewards + gamma * values[:, 1:] * (1 - on_reset[:, 1:]) - values[:, :-1]

    gae = torch.zeros_like(rewards[:, 0])
    adv = torch.zeros_like(rewards)

    # 1. If the next step is a new episode, GAE doesn't propagate back.
    # 2. If the next step is a truncated final step, the backpropagated GAE is -V(t),
    #    which is not correct. We ignore it such that the current GAE is r(t-1)+ɣV(t)-V(t-1).
    # 3. If the next step is a done final step, the backpropagated GAE is zero.
    m = gamma * lam * (1 - on_reset[:, 1:]) * (1 - truncates[:, 1:])

    step = episode_length - 1
    while step >= 0:
        gae = delta[:, step] + m[:, step] * gae
        adv[:, step] = gae
        step -= 1

    return adv, adv + values[:, :-1]


def cugae1d_nolp_misalign_func(
    rewards: torch.FloatTensor,
    values: torch.FloatTensor,
@@ -393,134 +335,14 @@ def cugae1d_nolp_misalign_func(
        Tuple[torch.FloatTensor, torch.FloatTensor]: Advantages and returns (value targets).
            Both have the same shape as rewards.
    """
    import realhf._C.cugae as gae_cuda
    if pkg_version.is_available("cugae"):
        from cugae import cugae1d_nolp_misalign_func as gae_1d_nolp_misalign
    else:
        from realhf._C.cugae import gae_1d_nolp_misalign

    assert len(rewards.shape) == len(values.shape) == len(cu_seqlens.shape) == 1
    assert cu_seqlens[0] == 0 and cu_seqlens[-1] == rewards.shape[0]
    return gae_cuda.gae_1d_nolp_misalign(
        rewards, values, cu_seqlens, truncate, gamma, lam
    )


def cugae2d_olp_func(
    rewards: torch.FloatTensor,
    values: torch.FloatTensor,
    dones: torch.BoolTensor,
    truncates: torch.BoolTensor,
    gamma: float,
    lam: float,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
    """Compute GAE over batched sequences with variable lengths, assuming
    overlapped sequences.

    This function assumes that rewards and values are batched as 2D tensors.
    The first dimension is batch_size and the second dimension is the number of collected timesteps.
    Each batch slot may contain multiple sequences, and sequences may have different lengths.
    The length of each sequence is marked by dones.

    `dones` marks the termination of each sequence, whether it is truncated or not.
    `truncates` marks truncation, and its nonzero indices must be a subset of those of `dones`.
    If truncated, GAE computation abandons the last step (because we don't have the bootstrapped
    value in this case) and starts from the second-to-last step.

    The final step of each sequence *is overlapped* by the first step of the next sequence,
    i.e., auto-reset, which has been widely used in libraries such as gym. In other words, the
    steps where `dones` is True are actually the first steps of sequences. Therefore,
    this function is suffixed with "olp" (overlap).

    Args:
        rewards (torch.FloatTensor): Shape [batch_size, seqlen].
        values (torch.FloatTensor): Shape [batch_size, seqlen + 1], with one more bootstrap step.
        dones (torch.BoolTensor): Shape [batch_size, seqlen + 1].
        truncates (torch.BoolTensor): Shape [batch_size, seqlen + 1].
        gamma (float): Discount factor.
        lam (float): GAE discount factor.

    Returns:
        Tuple[torch.FloatTensor, torch.FloatTensor]: Advantages and returns (value targets).
            Both have the same shape as rewards.
    """
    import realhf._C.cugae as gae_cuda

    truncates_indices = truncates.nonzero()
    assert torch.all(dones[truncates_indices[:, 0], truncates_indices[:, 1]])
    done_indices = dones.nonzero()
    num_dones = dones.float().sum(1)
    max_num_dones = int(num_dones.max())
    cu_num_dones = torch.nn.functional.pad(num_dones.cumsum(0), (1, 0), value=0).int()
    is_truncate = truncates[done_indices[:, 0], done_indices[:, 1]]
    return gae_cuda.gae_2d_olp(
        rewards,
        values,
        dones,
        done_indices[:, 1].int(),
        cu_num_dones,
        max_num_dones,
        is_truncate,
        gamma,
        lam,
    )


def cugae2d_nolp_func(
    rewards: torch.FloatTensor,
    values: torch.FloatTensor,
    on_reset: torch.BoolTensor,
    truncates: torch.BoolTensor,
    gamma: float,
    lam: float,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
    """Compute GAE over batched sequences with variable lengths, assuming non-
    overlapped sequences.

    This function assumes that rewards and values are batched as 2D tensors.
    The first dimension is batch_size and the second dimension is the number of collected timesteps.
    Each batch slot may contain multiple sequences, and sequences may have different lengths.
    The length of each sequence is marked by `on_reset`.

    `on_reset` marks the beginning of each sequence. `truncates` marks truncation.
    If truncated, values will be bootstrapped from the `done` step.

    The final step of each sequence is *NOT* overlapped by the first step of the next sequence.
    Each sequence will be complete. The last step should only have observations but no rewards.
    This is used in SRL. Therefore, this function is suffixed with "nolp" (non-overlap).

    Args:
        rewards (torch.FloatTensor): Shape [batch_size, seqlen].
        values (torch.FloatTensor): Shape [batch_size, seqlen + 1], with one more bootstrap step.
        dones (torch.BoolTensor): Shape [batch_size, seqlen + 1].
        truncates (torch.BoolTensor): Shape [batch_size, seqlen + 1].
        gamma (float): Discount factor.
        lam (float): GAE discount factor.

    Returns:
        Tuple[torch.FloatTensor, torch.FloatTensor]: Advantages and returns (value targets).
            Both have the same shape as rewards.
    """
    import realhf._C.cugae as gae_cuda

    dones = on_reset[:, 1:]
    truncates_indices = truncates[:, :-1].nonzero()
    assert torch.all(dones[truncates_indices[:, 0], truncates_indices[:, 1]])
    on_reset_indices = on_reset.nonzero()
    num_resets = on_reset.float().sum(1)
    max_num_resets = int(num_resets.max())
    cu_num_resets = torch.nn.functional.pad(num_resets.cumsum(0), (1, 0), value=0).int()
    truncates = torch.cat(
        [torch.zeros_like(truncates[:, 0:1]), truncates[:, :-1]], dim=1
    )
    bootstrap = truncates[on_reset_indices[:, 0], on_reset_indices[:, 1]]
    return gae_cuda.gae_2d_nolp(
        rewards,
        values,
        on_reset,
        on_reset_indices[:, 1].int(),
        cu_num_resets,
        max_num_resets,
        bootstrap,
        gamma,
        lam,
    )
    return gae_1d_nolp_misalign(rewards, values, cu_seqlens, truncate, gamma, lam)


@torch.no_grad()
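
The new import fallback in isolation: prefer the standalone cugae wheel when it is installed, otherwise use the kernel bundled as realhf._C.cugae. A small sketch wrapping it, reusing only the call signature and sanity checks that appear above (the wrapper name is made up):

    from realhf.base import pkg_version

    if pkg_version.is_available("cugae"):
        from cugae import cugae1d_nolp_misalign_func as gae_1d_nolp_misalign
    else:
        from realhf._C.cugae import gae_1d_nolp_misalign

    def packed_gae_1d(rewards, values, cu_seqlens, truncate, gamma, lam):
        # 1D packed tensors whose sequence boundaries are described by cu_seqlens.
        assert rewards.dim() == values.dim() == cu_seqlens.dim() == 1
        assert cu_seqlens[0] == 0 and cu_seqlens[-1] == rewards.shape[0]
        return gae_1d_nolp_misalign(rewards, values, cu_seqlens, truncate, gamma, lam)
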
@@ -9,11 +9,7 @@ import torch

from realhf.impl.model.utils.ppo_functional import (
    cugae1d_nolp_misalign_func,
    cugae2d_nolp_func,
    cugae2d_olp_func,
    pygae1d_nolp_misalign,
    pygae2d_nolp,
    pygae2d_olp,
)


@@ -58,78 +54,3 @@ def test_gae1d_nolp_misalign(max_seqlen: int, bs: int, gamma: float, lam: float)
        f"max_seqlen={max_seqlen},bs={bs}, CUDA acceleration ratio",
        (t2 - t1) / (t3 - t2),
    )


@pytest.mark.skipif(not torch.cuda.is_available(), reason="This test requires a GPU.")
@pytest.mark.parametrize("seqlen", [32, 128, 512, 1024])
@pytest.mark.parametrize("bs", [8, 16, 32, 100])
@pytest.mark.parametrize("gamma", [0.9, 1.0])
@pytest.mark.parametrize("lam", [0.5, 1.0])
def test_gae2d_olp(bs: int, seqlen: int, gamma: float, lam: float):

    torch.random.manual_seed(0)
    rewards = torch.randn(bs, seqlen).cuda()
    values = torch.randn(bs, seqlen + 1).cuda()
    dones = torch.randint(0, 2, (bs, seqlen + 1)).bool().cuda()
    truncates = dones.logical_and(torch.randint(0, 2, (bs, seqlen + 1)).bool().cuda())

    py_adv, py_ret = pygae2d_olp(rewards, values, dones, truncates, gamma, lam)
    adv, ret = cugae2d_olp_func(rewards, values, dones, truncates, gamma, lam)

    torch.cuda.synchronize()
    t1 = time.perf_counter_ns()
    py_adv, py_ret = pygae2d_olp(rewards, values, dones, truncates, gamma, lam)
    torch.cuda.synchronize()
    t2 = time.perf_counter_ns()
    adv, ret = cugae2d_olp_func(rewards, values, dones, truncates, gamma, lam)
    torch.cuda.synchronize()
    t3 = time.perf_counter_ns()

    assert torch.allclose(adv, py_adv, atol=1e-5), (adv - py_adv).abs().max()
    assert torch.allclose(ret, py_ret, atol=1e-5), (ret - py_ret).abs().max()
    print(
        f"seqlen={seqlen},bs={bs}, CUDA acceleration ratio",
        (t2 - t1) / (t3 - t2),
    )


@pytest.mark.skipif(not torch.cuda.is_available(), reason="This test requires a GPU.")
@pytest.mark.parametrize("seqlen", [32, 128, 512, 1024])
@pytest.mark.parametrize("bs", [8, 16, 32, 100])
@pytest.mark.parametrize("gamma", [0.9, 1.0])
@pytest.mark.parametrize("lam", [0.5, 1.0])
def test_gae2d_nolp(bs: int, seqlen: int, gamma: float, lam: float):

    torch.random.manual_seed(0)
    rewards = torch.randn(bs, seqlen).cuda()
    values = torch.randn(bs, seqlen + 1).cuda()
    on_reset_ = torch.randint(0, 2, (bs, seqlen + 2)).bool().cuda()
    on_reset = on_reset_[:, :-1].contiguous()
    truncates = (
        on_reset_[:, 1:]
        .logical_and(torch.randint(0, 2, (bs, seqlen + 1)).bool().cuda())
        .contiguous()
    )

    py_adv, py_ret = pygae2d_nolp(rewards, values, on_reset, truncates, gamma, lam)
    adv, ret = cugae2d_nolp_func(rewards, values, on_reset, truncates, gamma, lam)

    torch.cuda.synchronize()
    t1 = time.perf_counter_ns()
    py_adv, py_ret = pygae2d_nolp(rewards, values, on_reset, truncates, gamma, lam)
    torch.cuda.synchronize()
    t2 = time.perf_counter_ns()
    adv, ret = cugae2d_nolp_func(rewards, values, on_reset, truncates, gamma, lam)
    torch.cuda.synchronize()
    t3 = time.perf_counter_ns()

    adv = adv * (1 - truncates[:, :-1].float())
    py_adv = py_adv * (1 - truncates[:, :-1].float())
    ret = ret * (1 - truncates[:, :-1].float())
    py_ret = py_ret * (1 - truncates[:, :-1].float())
    assert torch.allclose(adv, py_adv, atol=1e-5), (adv - py_adv).abs().max()
    assert torch.allclose(ret, py_ret, atol=1e-5), (ret - py_ret).abs().max()
    print(
        f"seqlen={seqlen},bs={bs}, CUDA acceleration ratio",
        (t2 - t1) / (t3 - t2),
    )
@@ -112,7 +112,10 @@ def test_fn(

    from realhf.impl.model.backend.sglang import SGLangGenerationBackend

    backend = SGLangGenerationBackend(model_path=path)
    backend = SGLangGenerationBackend(
        model_path=path,
        # Both branches should yield a dtype string, not a torch.dtype object.
        dtype="bfloat16" if module.dtype == torch.bfloat16 else "float16",
    )
    model = model_api.Model(
        name=model_name,
        module=module,

@@ -188,7 +191,6 @@ def test_fn(

    print("success")

    # Clean up the distributed environment.
    dist.destroy_process_group()
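
Since the dtype argument above must be a string on both branches, a tiny helper (hypothetical name, not part of this PR) that covers other torch dtypes as well could look like:

    import torch

    def dtype_to_str(dtype: torch.dtype) -> str:
        # torch.bfloat16 -> "bfloat16", torch.float16 -> "float16", etc.
        return str(dtype).removeprefix("torch.")

    assert dtype_to_str(torch.bfloat16) == "bfloat16"
    assert dtype_to_str(torch.float16) == "float16"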