mirror of https://github.com/inclusionAI/AReaL
[Fix] [lite] Merge from the internal repo to fix GRPO bugs and refactor the train engine (#181)
* PullRequest: 353 [Lite] Add gradient checkpointing to FSDPEngine. Merge branch mzy/add-gradient-ckpt of git@code.alipay.com:inclusionAI/AReaL.git into lite. https://code.alipay.com/inclusionAI/AReaL/pull_requests/353 Reviewed-by: 博惟 <bowei.fw@antgroup.com>
  * add gradient checkpointing
* PullRequest: 354 [lite] GRPO pre-commit: minor changes in FSDP engine. Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite. https://code.alipay.com/inclusionAI/AReaL/pull_requests/354 Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* PullRequest: 355 [Lite] GRPO pre-commit 2: Refactor RemoteSGLangEngine thread and SGLang configuration. Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite. https://code.alipay.com/inclusionAI/AReaL/pull_requests/355?tab=commit Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
  * fix
* PullRequest: 357 [lite] GRPO pre-commit 3: Fix typos and experiment utilities. Merge branch fw/lite-fix2 of git@code.alipay.com:inclusionAI/AReaL.git into lite. https://code.alipay.com/inclusionAI/AReaL/pull_requests/357?tab=comment Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
  * fix destroy process group
* PullRequest: 358 [lite] Support GRPO training locally with the GSM8k dataset. Merge branch fw/lite-fix3 of git@code.alipay.com:inclusionAI/AReaL.git into lite. https://code.alipay.com/inclusionAI/AReaL/pull_requests/358 Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
  * fix loss mask
  * fix
* PullRequest: 368 [lite] Refactor train engine after merging contributions from GitHub. Merge branch fw/lite-train-engine of git@code.alipay.com:inclusionAI/AReaL.git into lite. https://code.alipay.com/inclusionAI/AReaL/pull_requests/368 Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* PullRequest: 371 [lite] [fix] Fix misc bugs in GRPO implementation. Merge branch fw/lite-fix0716 of git@code.alipay.com:inclusionAI/AReaL.git into lite. https://code.alipay.com/inclusionAI/AReaL/pull_requests/371 Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>

---------

Co-authored-by: 晓雷 <meizhiyu.mzy@antgroup.com>
parent 0283cfa124
commit 29e164a69d
@@ -18,7 +18,7 @@ from arealite.utils.fs import get_user_tmp
|
|||
class MicroBatchSpec:
|
||||
"""Specification for splitting micro-batches during training."""
|
||||
|
||||
n_mbs: int = field(
|
||||
n_mbs: Optional[int] = field(
|
||||
default=1,
|
||||
metadata={
|
||||
"help": "Number of micro-batches (or minimum number if max_tokens_per_mb is set). Used when max_tokens_per_mb is None or as minimum count",
|
||||
|
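For readers of this config change, a minimal usage sketch of the dataclass above (assuming MicroBatchSpec lives in arealite.api.cli_args with the other configs and that max_tokens_per_mb is one of its fields, as the help text suggests; both are assumptions, not confirmed by this diff):

# Ask for at least 2 micro-batches, splitting further if any micro-batch
# would exceed 4096 packed tokens (illustrative values).
from arealite.api.cli_args import MicroBatchSpec
spec = MicroBatchSpec(n_mbs=2, max_tokens_per_mb=4096)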
@@ -161,7 +161,7 @@ class FSDPEngineConfig:
|
|||
|
||||
|
||||
@dataclass
|
||||
class HFEngineConfig:
|
||||
class DeepSpeedAutoTPEngineConfig:
|
||||
autotp_size: Optional[int] = field(
|
||||
default=1,
|
||||
metadata={"help": "DeepSpeed AutoTP size"},
|
||||
|
@@ -201,7 +201,88 @@ class TrainEngineConfig:
|
|||
)
|
||||
backend: str = ""
|
||||
fsdp: FSDPEngineConfig = field(default_factory=FSDPEngineConfig)
|
||||
hf: HFEngineConfig = field(default_factory=HFEngineConfig)
|
||||
ds_auto_tp: DeepSpeedAutoTPEngineConfig = field(
|
||||
default_factory=DeepSpeedAutoTPEngineConfig
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PPOActorConfig(TrainEngineConfig):
|
||||
# Core PPO/GRPO Parameters
|
||||
group_size: int = field(
|
||||
default=1, metadata={"help": "Number of sequences in each group"}
|
||||
)
|
||||
group_adv_norm: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Normalize advantages within each prompt group rather than globally"
|
||||
},
|
||||
)
|
||||
ppo_n_minibatches: int = field(
|
||||
default=4, metadata={"help": "Number of minibatches for each PPO update"}
|
||||
)
|
||||
eps_clip: float = field(
|
||||
default=0.2, metadata={"help": "Clipping factor for policy ratio"}
|
||||
)
|
||||
c_clip: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Dual clipping factor for policy ratio, must > 1.0. None disables dual clipping."
|
||||
},
|
||||
)
|
||||
temperature: float = field(
|
||||
default=1.0, metadata={"help": "Temperature during generation."}
|
||||
)
|
||||
# Reward
|
||||
group_reward_norm: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Normalize final reward of each sequence (GRPO-style) to reduce length bias"
|
||||
},
|
||||
)
|
||||
reward_scaling: float = field(
|
||||
default=1.0, metadata={"help": "Reward scaling factor"}
|
||||
)
|
||||
reward_bias: float = field(default=0.0, metadata={"help": "Reward bias"})
|
||||
reward_clip: float = field(
|
||||
default=20.0, metadata={"help": "Maximum absolute value for reward clipping"}
|
||||
)
|
||||
mask_no_eos_with_zero: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Mask truncated generations (no EOS token) and exclude from training"
|
||||
},
|
||||
)
|
||||
|
||||
# Advantage Estimation
|
||||
discount: float = field(
|
||||
default=1.0, metadata={"help": "Discount factor for future rewards"}
|
||||
)
|
||||
gae_lambda: float = field(
|
||||
default=1.0, metadata={"help": "Lambda parameter for GAE"}
|
||||
)
|
||||
adv_norm: bool = field(
|
||||
default=True, metadata={"help": "Enable advantage normalization"}
|
||||
)
|
||||
|
||||
# KL Control
|
||||
kl_ctl: float = field(default=0.1, metadata={"help": "KL divergence coefficient"})
|
||||
|
||||
# Asynchronous RL
|
||||
recompute_logprob: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Recompute logp and replace the logp returned by inference."},
|
||||
)
|
||||
use_decoupled_loss: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use the decoupled loss. recompute_logprob must be True."},
|
||||
)
|
||||
behav_imp_weight_cap: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "We filter out the tokens where behav_imp_weight exceeds behav_imp_weight_cap when computing the loss, must be > 1.0, use_decoupled_loss must be true"
|
||||
},
|
||||
)
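To show how the new PPO/GRPO fields combine, a hedged configuration sketch (only fields visible in this diff are set; it assumes the inherited TrainEngineConfig fields all have defaults, and the kl_ctl choice is illustrative, not prescribed by the commit):

# GRPO-style actor: 8 samples per prompt, group-normalized rewards and
# advantages, decoupled loss with recomputed log-probabilities.
actor = PPOActorConfig(
    group_size=8,
    group_adv_norm=True,
    group_reward_norm=True,
    ppo_n_minibatches=4,
    eps_clip=0.2,
    kl_ctl=0.0,               # illustrative: disable KL regularization for this run
    recompute_logprob=True,
    use_decoupled_loss=True,  # requires recompute_logprob=True, per the help text
)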
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
@@ -25,10 +25,10 @@ class Scheduling:
|
|||
cpu: int
|
||||
gpu: int
|
||||
mem: int
|
||||
nodelist: str = None
|
||||
exclude: str = None
|
||||
partition: str = None
|
||||
container_image: str = None
|
||||
nodelist: Optional[str] = None
|
||||
exclude: Optional[str] = None
|
||||
partition: Optional[str] = None
|
||||
container_image: Optional[str] = None
|
||||
env_vars: Dict[str, str] = field(default_factory=dict)
|
||||
# time utils from "https://slurm.schedmd.com/sbatch.html"
|
||||
time_limit: Optional[str] = None # see "--time" option for format
|
||||
|
@@ -105,7 +105,7 @@ class TrainEngine(abc.ABC):
|
|||
def forward(
|
||||
self,
|
||||
input_: TensorDict,
|
||||
output_seqlens: List[List[int]] | None = None,
|
||||
output_seqlens: List[int] | None = None,
|
||||
post_hook: Callable[[torch.Tensor, TensorDict], Any] | None = None,
|
||||
aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
|
||||
) -> Any | None:
|
||||
|
|
|
@@ -71,7 +71,7 @@ class AllocationType(enum.Enum):
|
|||
@dataclass
|
||||
class AllocationMode:
|
||||
type_: AllocationType
|
||||
parallel_strat: None | Dict[str, Dict[str, int]]
|
||||
parallel_strat: Dict[str, Dict[str, int]]
|
||||
|
||||
@property
|
||||
def gen_tp_size(self) -> int:
|
||||
|
@@ -115,7 +115,7 @@ class AllocationMode:
|
|||
raise NotImplementedError(f"Failed to parse allocation: {allocation_mode}")
|
||||
|
||||
@staticmethod
|
||||
def extract_3d_alloc(allocation_mode: str) -> Dict | None:
|
||||
def extract_parallelism_strategy(allocation_mode: str) -> Dict:
|
||||
for x, y, z in itertools.permutations(["d", "t", "p"]):
|
||||
pattern = rf"{x}(\d+){y}(\d+){z}(\d+)"
|
||||
m = re.match(pattern, allocation_mode)
|
||||
|
@@ -130,29 +130,28 @@ class AllocationMode:
|
|||
z: c,
|
||||
}
|
||||
}
|
||||
raise ValueError(
|
||||
f"Unknown how to resolve parallelism strategy: {allocation_mode}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def extract_decoupled_alloc(allocation_mode: str) -> Dict | None:
|
||||
def extract_decoupled_alloc(allocation_mode: str) -> Dict:
|
||||
pattern = re.compile(
|
||||
r"(?:(?:vllm|sglang)\.(.+?)\+(.+))|(?:(.+?)\+(?:vllm|sglang)\.(.+))"
|
||||
)
|
||||
m = pattern.match(allocation_mode)
|
||||
if not m:
|
||||
return
|
||||
raise ValueError(
|
||||
f"Unknown how to resolve decoupled allocation: {allocation_mode}"
|
||||
)
|
||||
if m.group(1):
|
||||
gen_alloc = m.group(1)
|
||||
other_alloc = m.group(2)
|
||||
else:
|
||||
gen_alloc = m.group(4)
|
||||
other_alloc = m.group(3)
|
||||
gen_alloc = AllocationMode.extract_3d_alloc(gen_alloc)
|
||||
if not gen_alloc:
|
||||
return
|
||||
other_alloc = AllocationMode.extract_3d_alloc(
|
||||
other_alloc
|
||||
) or AllocationMode.extract_key_value_alloc(other_alloc)
|
||||
if not other_alloc:
|
||||
return
|
||||
gen_alloc = AllocationMode.extract_parallelism_strategy(gen_alloc)
|
||||
other_alloc = AllocationMode.extract_parallelism_strategy(other_alloc)
|
||||
other_alloc.update({"gen": gen_alloc["*"]})
|
||||
return other_alloc
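A brief usage sketch of the renamed helpers (the allocation strings are illustrative; the return shapes follow the code above):

# extract_parallelism_strategy("d4t2p1") -> {"*": {"d": 4, "t": 2, "p": 1}}
# A decoupled spec nests the generation backend's layout under "gen":
alloc = AllocationMode.extract_decoupled_alloc("sglang.d4t2p1+d2t2p2")
# -> {"*": {"d": 2, "t": 2, "p": 2}, "gen": {"d": 4, "t": 2, "p": 1}}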
|
||||
|
||||
|
@@ -171,7 +170,7 @@ class SaveLoadMeta:
|
|||
path: str
|
||||
weight_format: str
|
||||
with_optim: bool
|
||||
tokenizer: PreTrainedTokenizerFast | None
|
||||
tokenizer: Optional[PreTrainedTokenizerFast]
|
||||
base_model_path: str | None
|
||||
naive_distributed: bool = False
|
||||
|
||||
|
|
|
@@ -0,0 +1,128 @@
|
|||
import os
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from arealite.api.cli_args import TrainEngineConfig
|
||||
from arealite.api.engine_api import FinetuneSpec, SaveLoadMeta, WeightUpdateMeta
|
||||
from arealite.engine.base_hf_engine import BaseHFEngine
|
||||
from arealite.utils.save_load import (
|
||||
get_state_dict_from_repo_id_or_path,
|
||||
is_existing_local_path,
|
||||
)
|
||||
from realhf.base import constants, logging
|
||||
|
||||
logger = logging.getLogger("DeepSpeedAutoTPEngine")
|
||||
|
||||
|
||||
class DeepSpeedAutoTPEngine(BaseHFEngine):
|
||||
def __init__(self, config: TrainEngineConfig):
|
||||
super().__init__(config)
|
||||
|
||||
def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None):
|
||||
"""Initialize distributed communication and model."""
|
||||
assert (
|
||||
addr is None
|
||||
), "DeepSpeedAutoTPEngine does not support remote initialization."
|
||||
import deepspeed
|
||||
|
||||
self.create_process_group()
|
||||
|
||||
world_size = int(os.environ["WORLD_SIZE"])
|
||||
deepspeed.init_distributed(
|
||||
dist_backend="nccl",
|
||||
world_size=world_size,
|
||||
timeout=constants.NCCL_DEFAULT_TIMEOUT,
|
||||
)
|
||||
self.create_device_model()
|
||||
# NOTE: the device context manager does not work here.
|
||||
self.model = deepspeed.tp_model_init(
|
||||
self.model,
|
||||
tp_size=self.config.ds_auto_tp.autotp_size,
|
||||
dtype=getattr(torch, self.config.dtype),
|
||||
).to(self.device)
|
||||
self.create_optimizer(ft_spec)
|
||||
self.initialized = True
|
||||
|
||||
def _check_autotp(self):
|
||||
tp_size = self.config.ds_auto_tp.autotp_size
|
||||
config = self.model_config
|
||||
num_attention_heads = config.num_attention_heads
|
||||
num_key_value_heads = config.num_key_value_heads
|
||||
hidden_size = config.hidden_size
|
||||
intermediate_size = config.intermediate_size
|
||||
|
||||
return (
|
||||
num_attention_heads % tp_size == 0
|
||||
and num_key_value_heads % tp_size == 0
|
||||
and hidden_size % tp_size == 0
|
||||
and intermediate_size % tp_size == 0
|
||||
)
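A worked instance of the divisibility check above (the model dimensions are made up, not taken from a specific checkpoint):

# 32 attention heads, 8 KV heads, hidden 4096, intermediate 11008:
# tp_size 2 or 4 passes, tp_size 3 fails because 32 % 3 != 0.
for tp in (2, 3, 4):
    print(tp, all(x % tp == 0 for x in (32, 8, 4096, 11008)))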
|
||||
|
||||
def save(self, meta: SaveLoadMeta):
|
||||
if meta.weight_format != "naive_distributed":
|
||||
raise ValueError(f"Unknown weight format {meta.weight_format}. ")
|
||||
if self.model is None:
|
||||
raise RuntimeError("Model not initialized")
|
||||
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
if rank == 0:
|
||||
os.makedirs(meta.path, exist_ok=True)
|
||||
self.model_config.save_pretrained(
|
||||
meta.path,
|
||||
)
|
||||
if meta.tokenizer is not None:
|
||||
meta.tokenizer.save_pretrained(
|
||||
meta.path,
|
||||
)
|
||||
|
||||
state_dict = self.model.state_dict()
|
||||
if hasattr(self.model, "module"):
|
||||
state_dict = {
|
||||
k.replace("module.", "", 1) if k.startswith("module.") else k: v.cpu()
|
||||
for k, v in state_dict.items()
|
||||
}
|
||||
else:
|
||||
state_dict = {k: v.cpu() for k, v in state_dict.items()}
|
||||
|
||||
# Only supports saving each rank's model partition as a separate file
|
||||
gathered_state_dicts = None
|
||||
if rank == 0:
|
||||
gathered_state_dicts = [None for _ in range(world_size)]
|
||||
dist.gather_object(
|
||||
obj=state_dict, object_gather_list=gathered_state_dicts, dst=0
|
||||
)
|
||||
if rank == 0:
|
||||
for i, state_dict in enumerate(gathered_state_dicts):
|
||||
save_file(state_dict, f"{meta.path}/rank_{i:02d}_model.safetensors")
|
||||
if meta.with_optim:
|
||||
self.save_optimizer_state(meta.path)
|
||||
|
||||
def load(self, meta: SaveLoadMeta):
|
||||
if meta.weight_format != "naive_distributed":
|
||||
raise ValueError(f"Unknown weight format {meta.weight_format}. ")
|
||||
rank = dist.get_rank()
|
||||
# Only supports loading full model parameters from HuggingFace (rank 0)
|
||||
# or loading this rank's model partition from a local path
|
||||
if rank == 0 or is_existing_local_path(meta.path):
|
||||
path = f"{meta.path}/rank_{rank:02d}_model.safetensors"
|
||||
full_state = get_state_dict_from_repo_id_or_path(path)  # use the per-rank shard path built above
|
||||
|
||||
if hasattr(self.model, "module") and not hasattr(full_state):
|
||||
full_state = {
|
||||
f"module.{k}" if not k.startswith("module.") else k: v
|
||||
for k, v in full_state.items()
|
||||
}
|
||||
self.model.load_state_dict(
|
||||
full_state, strict=not self.model_config.tie_word_embeddings
|
||||
)
|
||||
if self.model_config.tie_word_embeddings:
|
||||
self.model.tie_weights()
|
||||
|
||||
if meta.with_optim:
|
||||
self.load_optimizer_state(meta.path)
|
||||
|
||||
def upload_weights(self, meta: WeightUpdateMeta):
|
||||
raise ValueError(f"update weight not implemented {meta.type}")
|
|
@@ -0,0 +1,360 @@
|
|||
import gc
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Callable, Dict, List
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from tensordict import TensorDict
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
PretrainedConfig,
|
||||
PreTrainedTokenizerFast,
|
||||
get_constant_schedule_with_warmup,
|
||||
get_linear_schedule_with_warmup,
|
||||
)
|
||||
|
||||
from arealite.api.cli_args import TrainEngineConfig
|
||||
from arealite.api.engine_api import FinetuneSpec, TrainEngine
|
||||
from arealite.utils.data import (
|
||||
MicroBatchList,
|
||||
amend_position_ids,
|
||||
pack_tensor_dict,
|
||||
pad_and_stack_tensors_along_first_dim,
|
||||
pad_mb_list,
|
||||
reorder_list,
|
||||
split_padded_tensor_dict_into_mb_list,
|
||||
unpack_sequence,
|
||||
unsqueeze_mb_list,
|
||||
)
|
||||
from arealite.utils.fsdp import get_cosine_schedule_with_warmup
|
||||
from arealite.utils.model import disable_dropout_in_model
|
||||
from realhf.api.core.data_api import load_hf_tokenizer
|
||||
from realhf.base import constants, logging
|
||||
|
||||
logger = logging.getLogger("Base HF Engine")
|
||||
|
||||
|
||||
class BaseHFEngine(TrainEngine):
|
||||
def __init__(self, config: TrainEngineConfig):
|
||||
self.config = config
|
||||
self.optimizer_config = config.optimizer
|
||||
|
||||
self.model: torch.nn.Module
|
||||
self.optimizer: torch.optim.Optimizer
|
||||
self.tokenizer: PreTrainedTokenizerFast
|
||||
# huggingface model config
|
||||
self.model_config: PretrainedConfig
|
||||
|
||||
# initialization
|
||||
self.initialized = False
|
||||
self.own_global_group = False
|
||||
self._parallelism_group: dist.ProcessGroup
|
||||
self.weight_update_group_initialized = False
|
||||
|
||||
self.world_size = int(os.environ["WORLD_SIZE"])
|
||||
|
||||
def train(self, mode: bool = True):
|
||||
assert self.model is not None
|
||||
self.model.train(mode=mode)
|
||||
return self
|
||||
|
||||
@property
|
||||
def parallelism_group(self) -> dist.ProcessGroup:
|
||||
assert self.initialized
|
||||
return self._parallelism_group
|
||||
|
||||
def create_process_group(self):
|
||||
if not dist.is_initialized():
|
||||
# TODO: Handle the case where WORLD_SIZE and RANK are not set by the launcher
|
||||
dist.init_process_group(
|
||||
backend="nccl",
|
||||
timeout=constants.NCCL_DEFAULT_TIMEOUT,
|
||||
device_id=torch.device(int(os.environ["LOCAL_RANK"])),
|
||||
)
|
||||
self.own_global_group = True
|
||||
self._parallelism_group = dist.new_group()
|
||||
|
||||
def create_device_model(self):
|
||||
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
|
||||
self.device = torch.device(int(os.environ["LOCAL_RANK"]))
|
||||
|
||||
dtype = getattr(torch, self.config.dtype)
|
||||
self.model_config = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path=self.config.path,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
self.tokenizer = load_hf_tokenizer(self.config.path)
|
||||
tik = time.perf_counter()
|
||||
with torch.device("cuda"):
|
||||
if self.config.init_from_scratch:
|
||||
# initialize scratch model from config
|
||||
# NOTE: VLM cannot directly load state dict using this
|
||||
# random initialized model, so otherwise we call
|
||||
# from_pretrained rather than loading weights into this random model.
|
||||
model = AutoModelForCausalLM.from_config(
|
||||
self.model_config,
|
||||
torch_dtype=dtype,
|
||||
attn_implementation=self.config.attn_impl,
|
||||
)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
pretrained_model_name_or_path=self.config.path,
|
||||
trust_remote_code=True,
|
||||
torch_dtype=dtype,
|
||||
attn_implementation=self.config.attn_impl,
|
||||
)
|
||||
if self.config.disable_dropout:
|
||||
disable_dropout_in_model(model)
|
||||
if self.config.gradient_checkpointing:
|
||||
model.gradient_checkpointing_enable(
|
||||
gradient_checkpointing_kwargs={"use_reentrant": False}
|
||||
)
|
||||
logger.info(f"Model creation and loading time: {time.perf_counter() - tik}")
|
||||
self.model = model
|
||||
|
||||
def create_optimizer(self, ft_spec: FinetuneSpec):
|
||||
if self.optimizer_config is None:
|
||||
return
|
||||
assert self.model is not None
|
||||
# Set up optimizer
|
||||
tik = time.perf_counter()
|
||||
assert (
|
||||
self.optimizer_config.type == "adam"
|
||||
), "Only AdamW optimizer is supported in this engine."
|
||||
lr = self.optimizer_config.lr
|
||||
weight_decay = self.optimizer_config.weight_decay
|
||||
beta1 = self.optimizer_config.beta1
|
||||
beta2 = self.optimizer_config.beta2
|
||||
eps = self.optimizer_config.eps
|
||||
|
||||
self.optimizer = torch.optim.AdamW(
|
||||
self.model.parameters(),
|
||||
lr=lr,
|
||||
weight_decay=weight_decay,
|
||||
betas=(beta1, beta2),
|
||||
eps=eps,
|
||||
)
|
||||
total_train_steps = ft_spec.total_train_steps
|
||||
num_warmup_steps = int(
|
||||
self.optimizer_config.warmup_steps_proportion * total_train_steps
|
||||
)
|
||||
|
||||
if self.optimizer_config.lr_scheduler_type == "cosine":
|
||||
self.lr_scheduler = get_cosine_schedule_with_warmup(
|
||||
self.optimizer,
|
||||
num_warmup_steps,
|
||||
total_train_steps,
|
||||
min_lr_ratio=self.optimizer_config.min_lr_ratio,
|
||||
)
|
||||
elif self.optimizer_config.lr_scheduler_type == "linear":
|
||||
self.lr_scheduler = get_linear_schedule_with_warmup(
|
||||
self.optimizer,
|
||||
num_warmup_steps,
|
||||
total_train_steps,
|
||||
)
|
||||
elif self.optimizer_config.lr_scheduler_type == "constant":
|
||||
self.lr_scheduler = get_constant_schedule_with_warmup(
|
||||
self.optimizer,
|
||||
num_warmup_steps,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown lr scheduler type {self.optimizer_config.lr_scheduler_type}"
|
||||
)
|
||||
logger.info(f"Create optimizer time: {time.perf_counter() - tik}")
|
||||
|
||||
def destroy(self):
|
||||
"""Destroy the engine and release GPU memory."""
|
||||
del self.optimizer
|
||||
del self.model
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
dist.destroy_process_group(self.parallelism_group)
|
||||
if self.own_global_group:
|
||||
dist.destroy_process_group()
|
||||
self.initialized = False
|
||||
|
||||
def save_optimizer_state(self, path: str):
|
||||
# Save FSDP sharded state dict on each rank
|
||||
assert self.optimizer is not None
|
||||
assert dist.is_initialized()
|
||||
rank = dist.get_rank()
|
||||
shard_path = os.path.join(
|
||||
path, f"optim_world_size_{self.world_size}_rank_{rank}.pt"
|
||||
)
|
||||
state_dict = self.optimizer.state_dict()
|
||||
torch.save(state_dict, shard_path)
|
||||
dist.barrier()
|
||||
|
||||
def load_optimizer_state(self, path: str):
|
||||
# Load FSDP sharded state dict
|
||||
assert self.optimizer is not None
|
||||
assert dist.is_initialized()
|
||||
rank = dist.get_rank()
|
||||
shard_path = os.path.join(
|
||||
path, f"optim_world_size_{self.world_size}_rank_{rank}.pt"
|
||||
)
|
||||
optimizer_state_dict = torch.load(shard_path, weights_only=False)
|
||||
self.optimizer.load_state_dict(optimizer_state_dict)
|
||||
dist.barrier()
|
||||
|
||||
def step_lr_scheduler(self):
|
||||
assert self.lr_scheduler is not None
|
||||
self.lr_scheduler.step()
|
||||
|
||||
def prepare_mb_list(self, input_: TensorDict) -> MicroBatchList:
|
||||
assert "attention_mask" in input_ and "input_ids" in input_
|
||||
if isinstance(input_, dict):
|
||||
input_ = TensorDict(input_, batch_size=[input_["input_ids"].shape[0]])
|
||||
input_ = amend_position_ids(input_)
|
||||
mb_list = split_padded_tensor_dict_into_mb_list(input_, self.config.mb_spec)
|
||||
logger.info(
|
||||
f"Microbatch #tokens (rank {dist.get_rank()}): {mb_list.group_lens}"
|
||||
)
|
||||
mb_list.mbs = [pack_tensor_dict(mb) for mb in mb_list.mbs]
|
||||
mb_list = pad_mb_list(mb_list, pad_value=0.0)
|
||||
# NOTE: We unsqueeze here because huggingface transformer models require
|
||||
# packed input to be of shape [1, total_seqlen].
|
||||
mb_list = unsqueeze_mb_list(mb_list)
|
||||
# FIXME: the resulting max_seqlen is a tensor rather than an integer
|
||||
for mb in mb_list.mbs:
|
||||
mb["max_seqlen"] = int(mb["max_seqlen"])
|
||||
mb["use_cache"] = False
|
||||
for mb in mb_list.padded_mbs:
|
||||
mb["max_seqlen"] = int(mb["max_seqlen"])
|
||||
mb["use_cache"] = False
|
||||
return mb_list
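A minimal input sketch for prepare_mb_list, given an initialized engine (shapes and values are illustrative; only input_ids and attention_mask are required, as asserted above):

import torch
from tensordict import TensorDict

batch = TensorDict(
    {
        "input_ids": torch.randint(0, 1000, (4, 128)),            # [batch, padded_len]
        "attention_mask": torch.ones(4, 128, dtype=torch.bool),
    },
    batch_size=[4],
)
mb_list = engine.prepare_mb_list(batch)  # packed, padded, unsqueezed micro-batches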
|
||||
|
||||
def train_batch(
|
||||
self,
|
||||
input_: TensorDict,
|
||||
loss_fn: Callable[[torch.Tensor, TensorDict], torch.Tensor],
|
||||
loss_weight_fn: Callable[[TensorDict], float],
|
||||
) -> Dict[str, float]:
|
||||
"""Train on a batch using gradient accumulation."""
|
||||
input_ = input_.to(self.device)
|
||||
assert self.optimizer is not None
|
||||
assert self.optimizer_config is not None
|
||||
assert self.lr_scheduler is not None
|
||||
|
||||
self.optimizer.zero_grad()
|
||||
mb_list = self.prepare_mb_list(input_)
|
||||
|
||||
total_loss_weight = torch.tensor(
|
||||
sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32
|
||||
)
|
||||
assert total_loss_weight != 0
|
||||
dist.all_reduce(total_loss_weight)
|
||||
|
||||
# Process microbatches with gradient accumulation
|
||||
for i, (pad_length, padded_mb_input, mb_input) in enumerate(
|
||||
zip(mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs)
|
||||
):
|
||||
outputs = self.model(**padded_mb_input)
|
||||
|
||||
logits = outputs.logits.squeeze(0)
|
||||
logits = logits[:-pad_length] if pad_length > 0 else logits
|
||||
loss = loss_fn(logits, mb_input)
|
||||
loss_scale = loss_weight_fn(mb_input) / total_loss_weight
|
||||
|
||||
# Scale loss for accumulation
|
||||
# Revert gradient averaging across dp ranks
|
||||
loss_scale *= self.world_size
|
||||
|
||||
loss *= loss_scale
|
||||
loss.backward()
|
||||
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
self.model.parameters(),
|
||||
self.optimizer_config.gradient_clipping,
|
||||
norm_type=2.0,
|
||||
error_if_nonfinite=False,
|
||||
foreach=None,
|
||||
)
|
||||
if not torch.isfinite(grad_norm):
|
||||
self.optimizer.zero_grad()
|
||||
update_successful = False
|
||||
else:
|
||||
self.optimizer.step()
|
||||
update_successful = True
|
||||
|
||||
current_lr = self.lr_scheduler.get_last_lr()[0]
|
||||
|
||||
return dict(
|
||||
update_successful=float(update_successful),
|
||||
grad_norm=float(grad_norm) if grad_norm is not None else float("nan"),
|
||||
lr=current_lr,
|
||||
)
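The scaling in train_batch amounts to a token-weighted average that also cancels the mean applied by data-parallel gradient reduction. A hedged worked example with made-up numbers: with 2 DP ranks, each holding micro-batches of 100 and 300 tokens, total_loss_weight after all_reduce is 800; each micro-batch backpropagates loss * (its_tokens / 800) * world_size, and after the backend divides gradients by world_size its effective weight is its_tokens / 800, i.e. a global token-weighted mean over all ranks and micro-batches.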
|
||||
|
||||
@torch.no_grad()
|
||||
def eval_batch(
|
||||
self,
|
||||
input_: TensorDict,
|
||||
loss_fn: Callable[[torch.Tensor, TensorDict], torch.Tensor],
|
||||
loss_weight_fn: Callable[[TensorDict], float],
|
||||
) -> torch.Tensor | None:
|
||||
"""Evaluate on a batch."""
|
||||
input_ = input_.to(self.device)
|
||||
mb_list = self.prepare_mb_list(input_)
|
||||
total_loss_weight = torch.tensor(
|
||||
sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32
|
||||
)
|
||||
assert total_loss_weight != 0
|
||||
|
||||
total_loss = 0.0
|
||||
total_weight = 0.0
|
||||
|
||||
for pad_length, padded_mb_input, mb_input in zip(
|
||||
mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
|
||||
):
|
||||
outputs = self.model(**padded_mb_input)
|
||||
logits = outputs.logits.squeeze(0)
|
||||
logits = logits[:-pad_length] if pad_length > 0 else logits
|
||||
loss = loss_fn(logits, mb_input)
|
||||
|
||||
# Simple weight calculation (could be improved)
|
||||
loss_scale = loss_weight_fn(mb_input) / total_loss_weight
|
||||
total_loss += loss.item() * loss_scale
|
||||
total_weight += loss_scale
|
||||
|
||||
return torch.tensor(total_loss / total_weight)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
input_: TensorDict,
|
||||
output_seqlens: List[int] | None = None,
|
||||
post_hook: Callable[[torch.Tensor, TensorDict], Any] | None = None,
|
||||
aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
|
||||
) -> Any | None:
|
||||
"""Forward pass with optional post-processing."""
|
||||
input_ = input_.to(self.device)
|
||||
cu_seqlens = pack_tensor_dict(input_)["cu_seqlens"]
|
||||
mb_list = self.prepare_mb_list(input_)
|
||||
|
||||
if output_seqlens is None:
|
||||
output_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist()
|
||||
|
||||
results = []
|
||||
for pad_length, padded_mb_input, mb_input in zip(
|
||||
mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
|
||||
):
|
||||
outputs = self.model(**padded_mb_input)
|
||||
logits = outputs.logits.squeeze(0)
|
||||
logits = logits[:-pad_length] if pad_length > 0 else logits
|
||||
|
||||
if post_hook:
|
||||
result = post_hook(logits, mb_input)
|
||||
results.append(result)
|
||||
else:
|
||||
results.append(logits)
|
||||
|
||||
res = aggregate_fn(results)
|
||||
output_seqlens = [output_seqlens[i] for i in mb_list.forward_indices]
|
||||
unpacked = unpack_sequence(res, lens=output_seqlens, dim=0)
|
||||
reordered = reorder_list(unpacked, mb_list.backward_indices)
|
||||
return pad_and_stack_tensors_along_first_dim(reordered)
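A hedged sketch of using forward's post_hook to turn logits into per-token log-probabilities without keeping the full vocabulary logits for the whole batch (gather_logprobs and batch are illustrative names, not helpers from this diff; the next-token shift is omitted for brevity):

import torch

def gather_logprobs(logits, mb):
    # logits: [total_tokens, vocab]; mb["input_ids"]: [1, total_tokens]
    logp = torch.log_softmax(logits.float(), dim=-1)
    tokens = mb["input_ids"].squeeze(0).unsqueeze(-1)
    return logp.gather(-1, tokens).squeeze(-1)

logprobs = engine.forward(batch, post_hook=gather_logprobs)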
|
|
@@ -1,42 +1,20 @@
|
|||
import gc
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import transformers
|
||||
from tensordict import TensorDict
|
||||
from torch.distributed.checkpoint.state_dict import (
|
||||
StateDictOptions,
|
||||
get_model_state_dict,
|
||||
)
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
get_constant_schedule_with_warmup,
|
||||
get_linear_schedule_with_warmup,
|
||||
)
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
|
||||
from arealite.api.cli_args import TrainEngineConfig
|
||||
from arealite.api.engine_api import (
|
||||
FinetuneSpec,
|
||||
SaveLoadMeta,
|
||||
TrainEngine,
|
||||
WeightUpdateMeta,
|
||||
)
|
||||
from arealite.utils.data import (
|
||||
MicroBatchList,
|
||||
amend_position_ids,
|
||||
pack_tensor_dict,
|
||||
pad_and_stack_tensors_along_first_dim,
|
||||
pad_mb_list,
|
||||
reorder_list,
|
||||
split_padded_tensor_dict_into_mb_list,
|
||||
unpack_sequence,
|
||||
unsqueeze_mb_list,
|
||||
)
|
||||
from arealite.api.engine_api import FinetuneSpec, SaveLoadMeta, WeightUpdateMeta
|
||||
from arealite.engine.base_hf_engine import BaseHFEngine
|
||||
from arealite.utils.fsdp import (
|
||||
CPUOffloadPolicy,
|
||||
MixedPrecisionPolicy,
|
||||
|
@@ -44,108 +22,35 @@ from arealite.utils.fsdp import (
|
|||
create_fsdp_device_mesh,
|
||||
fsdp2_clip_grad_norm_,
|
||||
fsdp2_load_full_state_dict,
|
||||
get_cosine_schedule_with_warmup,
|
||||
)
|
||||
from arealite.utils.model import disable_dropout_in_model
|
||||
from arealite.utils.save_load import get_state_dict_from_repo_id_or_path
|
||||
from realhf.api.core.data_api import load_hf_tokenizer
|
||||
from realhf.base import constants, logging, name_resolve, names, pkg_version
|
||||
from realhf.base import logging, name_resolve, names, pkg_version
|
||||
|
||||
logger = logging.getLogger("FSDPEngine")
|
||||
|
||||
|
||||
class FSDPEngine(TrainEngine):
|
||||
class FSDPEngine(BaseHFEngine):
|
||||
def __init__(self, config: TrainEngineConfig):
|
||||
self.config = config
|
||||
self.optimizer_config = config.optimizer
|
||||
|
||||
self.model = None
|
||||
self.optimizer = None
|
||||
self.tokenizer = None
|
||||
# huggingface model config
|
||||
self.model_config = None
|
||||
super().__init__(config)
|
||||
# FSDP options
|
||||
self.mixed_precision_policy = None
|
||||
self.device_mesh = None
|
||||
self.cpu_offload = None
|
||||
# initialization
|
||||
self.initialized = False
|
||||
self.own_global_group = False
|
||||
self._parallelism_group = None
|
||||
self.weight_update_group_initialized = False
|
||||
|
||||
# TODO: Handle the case when WORLD_SIZE is not set in launcher
|
||||
self.world_size = int(os.environ["WORLD_SIZE"])
|
||||
|
||||
def train(self, mode: bool = True):
|
||||
assert self.model is not None
|
||||
self.model.train(mode=mode)
|
||||
return self
|
||||
|
||||
@property
|
||||
def parallelism_group(self) -> dist.ProcessGroup:
|
||||
assert self.initialized
|
||||
return self._parallelism_group
|
||||
|
||||
def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None):
|
||||
# Initialize distributed environments and load the model.
|
||||
assert addr is None, "FSDPEngine does not support remote initialization."
|
||||
|
||||
assert pkg_version.is_version_greater_or_equal(
|
||||
"torch", "2.4.0"
|
||||
), f"arealite only supports FSDP2, which requires torch>=2.4.0"
|
||||
|
||||
"""Initialize distributed communication and model."""
|
||||
if not dist.is_initialized():
|
||||
# TODO: Handle the condition when WORLD_SIZE and RANK is not set in launcher
|
||||
dist.init_process_group(
|
||||
backend="nccl",
|
||||
timeout=constants.NCCL_DEFAULT_TIMEOUT,
|
||||
device_id=torch.device(int(os.environ["LOCAL_RANK"])),
|
||||
)
|
||||
self.own_global_group = True
|
||||
self._parallelism_group = dist.new_group()
|
||||
|
||||
# TODO: Handle the condition when LOCAL_RANK is not set in launcher
|
||||
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
|
||||
self.device = torch.device(int(os.environ["LOCAL_RANK"]))
|
||||
|
||||
dtype = getattr(torch, self.config.dtype)
|
||||
self.model_config = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path=self.config.path,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
self.tokenizer = load_hf_tokenizer(self.config.path)
|
||||
tik = time.perf_counter()
|
||||
with torch.device("cuda"):
|
||||
if self.config.init_from_scratch:
|
||||
# initialize scratch model from config
|
||||
# NOTE: VLM cannot directly load state dict using this
|
||||
# random initialized model, so otherwise we call
|
||||
# from_pretrained rather than loading weights into this random model.
|
||||
model = AutoModelForCausalLM.from_config(
|
||||
self.model_config,
|
||||
torch_dtype=dtype,
|
||||
attn_implementation=self.config.attn_impl,
|
||||
)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
pretrained_model_name_or_path=self.config.path,
|
||||
trust_remote_code=True,
|
||||
torch_dtype=dtype,
|
||||
attn_implementation=self.config.attn_impl,
|
||||
)
|
||||
if self.config.disable_dropout:
|
||||
disable_dropout_in_model(model)
|
||||
if self.config.gradient_checkpointing:
|
||||
model.gradient_checkpointing_enable(
|
||||
gradient_checkpointing_kwargs={"use_reentrant": False}
|
||||
)
|
||||
logger.info(f"Model creation and loading time: {time.perf_counter() - tik}")
|
||||
self.create_process_group()
|
||||
self.create_device_model()
|
||||
|
||||
# Wrap with FSDP2
|
||||
# Simple auto wrap policy
|
||||
self.mixed_precision_policy = MixedPrecisionPolicy(
|
||||
param_dtype=dtype,
|
||||
param_dtype=getattr(torch, self.config.dtype),
|
||||
reduce_dtype=torch.float32,
|
||||
cast_forward_inputs=True,
|
||||
)
|
||||
|
@@ -154,82 +59,19 @@ class FSDPEngine(TrainEngine):
|
|||
self.cpu_offload = (
|
||||
CPUOffloadPolicy() if self.config.fsdp.offload_params else None
|
||||
)
|
||||
|
||||
fsdp_kwargs = {
|
||||
"mesh": self.device_mesh,
|
||||
"mp_policy": self.mixed_precision_policy,
|
||||
"offload_policy": self.cpu_offload,
|
||||
"reshard_after_forward": True,
|
||||
}
|
||||
|
||||
# Wrap with FSDP2
|
||||
tik = time.perf_counter()
|
||||
apply_fsdp2(model, fsdp_kwargs, self.config.fsdp.wrap_policy)
|
||||
apply_fsdp2(self.model, fsdp_kwargs, self.config.fsdp.wrap_policy)
|
||||
logger.info(f"Applying FSDP2 time: {time.perf_counter() - tik}")
|
||||
self.model = model
|
||||
|
||||
# Set up optimizer
|
||||
if self.optimizer_config is not None:
|
||||
tik = time.perf_counter()
|
||||
assert (
|
||||
self.optimizer_config.type == "adam"
|
||||
), "Only AdamW optimizer is supported in this engine."
|
||||
lr = self.optimizer_config.lr
|
||||
weight_decay = self.optimizer_config.weight_decay
|
||||
beta1 = self.optimizer_config.beta1
|
||||
beta2 = self.optimizer_config.beta2
|
||||
eps = self.optimizer_config.eps
|
||||
|
||||
self.optimizer = torch.optim.AdamW(
|
||||
self.model.parameters(),
|
||||
lr=lr,
|
||||
weight_decay=weight_decay,
|
||||
betas=(beta1, beta2),
|
||||
eps=eps,
|
||||
)
|
||||
total_train_steps = ft_spec.total_train_steps
|
||||
num_warmup_steps = int(
|
||||
self.optimizer_config.warmup_steps_proportion * total_train_steps
|
||||
)
|
||||
|
||||
if self.optimizer_config.lr_scheduler_type == "cosine":
|
||||
self.lr_scheduler = get_cosine_schedule_with_warmup(
|
||||
self.optimizer,
|
||||
num_warmup_steps,
|
||||
total_train_steps,
|
||||
min_lr_ratio=self.optimizer_config.min_lr_ratio,
|
||||
)
|
||||
elif self.optimizer_config.lr_scheduler_type == "linear":
|
||||
self.lr_scheduler = get_linear_schedule_with_warmup(
|
||||
self.optimizer,
|
||||
num_warmup_steps,
|
||||
total_train_steps,
|
||||
)
|
||||
elif self.optimizer_config.lr_scheduler_type == "constant":
|
||||
self.lr_scheduler = get_constant_schedule_with_warmup(
|
||||
self.optimizer,
|
||||
num_warmup_steps,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown lr scheduler type {self.optimizer_config.lr_scheduler_type}"
|
||||
)
|
||||
logger.info(f"Create optimizer time: {time.perf_counter() - tik}")
|
||||
|
||||
self.create_optimizer(ft_spec)
|
||||
self.initialized = True
|
||||
|
||||
def destroy(self):
|
||||
"""Destroy the engine and release GPU memory."""
|
||||
self.model = None
|
||||
self.optimizer = None
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
dist.destroy_process_group(self.parallelism_group)
|
||||
if self.own_global_group:
|
||||
dist.destroy_process_group()
|
||||
self.initialized = False
|
||||
|
||||
def save(self, meta: SaveLoadMeta):
|
||||
if meta.weight_format == "hf":
|
||||
self._save_model_to_hf(meta.path, meta.tokenizer)
|
||||
|
@@ -240,7 +82,7 @@ class FSDPEngine(TrainEngine):
|
|||
raise ValueError(f"Unknown weight format {meta.weight_format}. ")
|
||||
|
||||
if meta.with_optim:
|
||||
self._save_optimizer_state(meta.path)
|
||||
self.save_optimizer_state(meta.path)
|
||||
|
||||
def load(self, meta: SaveLoadMeta):
|
||||
if meta.weight_format == "hf":
|
||||
|
@@ -252,34 +94,10 @@ class FSDPEngine(TrainEngine):
|
|||
raise ValueError(f"Unknown weight format {meta.weight_format}. ")
|
||||
|
||||
if meta.with_optim:
|
||||
self._load_optimizer_state(meta.path)
|
||||
|
||||
def _save_optimizer_state(self, path: str):
|
||||
# Save FSDP sharded state dict on each rank
|
||||
assert self.optimizer is not None
|
||||
assert dist.is_initialized()
|
||||
rank = dist.get_rank()
|
||||
shard_path = os.path.join(
|
||||
path, f"optim_world_size_{self.world_size}_rank_{rank}.pt"
|
||||
)
|
||||
state_dict = self.optimizer.state_dict()
|
||||
torch.save(state_dict, shard_path)
|
||||
dist.barrier()
|
||||
|
||||
def _load_optimizer_state(self, path: str):
|
||||
# Load FSDP sharded state dict
|
||||
assert self.optimizer is not None
|
||||
assert dist.is_initialized()
|
||||
rank = dist.get_rank()
|
||||
shard_path = os.path.join(
|
||||
path, f"optim_world_size_{self.world_size}_rank_{rank}.pt"
|
||||
)
|
||||
optimizer_state_dict = torch.load(shard_path, weights_only=False)
|
||||
self.optimizer.load_state_dict(optimizer_state_dict)
|
||||
dist.barrier()
|
||||
self.load_optimizer_state(meta.path)
|
||||
|
||||
def _save_model_to_hf(
|
||||
self, path: str, tokenizer: Optional[transformers.PreTrainedTokenizerFast]
|
||||
self, path: str, tokenizer: Optional[PreTrainedTokenizerFast]
|
||||
):
|
||||
"""Save model in HuggingFace format."""
|
||||
if self.model is None:
|
||||
|
@@ -345,35 +163,11 @@ class FSDPEngine(TrainEngine):
|
|||
"Distributed weight update is not implemented for FSDPEngine yet. "
|
||||
)
|
||||
|
||||
def step_lr_scheduler(self):
|
||||
assert self.lr_scheduler is not None
|
||||
self.lr_scheduler.step()
|
||||
|
||||
def _prepare_mb_list(self, input_: TensorDict) -> MicroBatchList:
|
||||
assert "attention_mask" in input_ and "input_ids" in input_
|
||||
if isinstance(input_, dict):
|
||||
input_ = TensorDict(input_, batch_size=[input_["input_ids"].shape[0]])
|
||||
input_ = amend_position_ids(input_)
|
||||
mb_list = split_padded_tensor_dict_into_mb_list(input_, self.config.mb_spec)
|
||||
logger.info(
|
||||
f"Microbatch #tokens (rank {dist.get_rank()}): {mb_list.group_lens}"
|
||||
)
|
||||
mb_list.mbs = [pack_tensor_dict(mb) for mb in mb_list.mbs]
|
||||
mb_list = pad_mb_list(mb_list, pad_value=0.0)
|
||||
# NOTE: We unsqueeze here because huggingface transformer models requires
|
||||
# packed input to be of shape [1, total_seqlen].
|
||||
mb_list = unsqueeze_mb_list(mb_list)
|
||||
# FIXME: the resulting max_seqlen is a tensor rather than an integer
|
||||
for mb in mb_list.mbs:
|
||||
mb["max_seqlen"] = int(mb["max_seqlen"])
|
||||
mb["use_cache"] = False
|
||||
return mb_list
|
||||
|
||||
def train_batch(
|
||||
self,
|
||||
input_: TensorDict,
|
||||
loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
|
||||
loss_weight_fn: Callable[[Dict], float],
|
||||
loss_fn: Callable[[torch.Tensor, TensorDict], torch.Tensor],
|
||||
loss_weight_fn: Callable[[TensorDict], float],
|
||||
) -> Dict[str, float]:
|
||||
"""Train on a batch using gradient accumulation."""
|
||||
input_ = input_.to(self.device)
|
||||
|
@@ -382,7 +176,7 @@ class FSDPEngine(TrainEngine):
|
|||
assert self.lr_scheduler is not None
|
||||
|
||||
self.optimizer.zero_grad()
|
||||
mb_list = self._prepare_mb_list(input_)
|
||||
mb_list = self.prepare_mb_list(input_)
|
||||
|
||||
total_loss_weight = torch.tensor(
|
||||
sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32
|
||||
|
@@ -394,7 +188,6 @@ class FSDPEngine(TrainEngine):
|
|||
for i, (pad_length, padded_mb_input, mb_input) in enumerate(
|
||||
zip(mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs)
|
||||
):
|
||||
self.model.set_is_last_backward(i == len(mb_list.mbs) - 1)
|
||||
outputs = self.model(**padded_mb_input)
|
||||
|
||||
logits = outputs.logits.squeeze(0)
|
||||
|
@@ -409,6 +202,7 @@ class FSDPEngine(TrainEngine):
|
|||
loss *= loss_scale
|
||||
loss.backward()
|
||||
|
||||
# NOTE: grad norm clip function is different
|
||||
grad_norm = fsdp2_clip_grad_norm_(
|
||||
self.model.parameters(), max_norm=self.optimizer_config.gradient_clipping
|
||||
)
|
||||
|
@@ -427,72 +221,3 @@ class FSDPEngine(TrainEngine):
|
|||
grad_norm=float(grad_norm) if grad_norm is not None else float("nan"),
|
||||
lr=current_lr,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def eval_batch(
|
||||
self,
|
||||
input_: TensorDict,
|
||||
loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
|
||||
loss_weight_fn: Callable[[Dict], float],
|
||||
) -> torch.Tensor | None:
|
||||
"""Evaluate on a batch."""
|
||||
input_ = input_.to(self.device)
|
||||
mb_list = self._prepare_mb_list(input_)
|
||||
total_loss_weight = torch.tensor(
|
||||
sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32
|
||||
)
|
||||
assert total_loss_weight != 0
|
||||
|
||||
total_loss = 0.0
|
||||
total_weight = 0.0
|
||||
|
||||
for pad_length, padded_mb_input, mb_input in zip(
|
||||
mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
|
||||
):
|
||||
outputs = self.model(**padded_mb_input)
|
||||
logits = outputs.logits.squeeze(0)
|
||||
logits = logits[:-pad_length] if pad_length > 0 else logits
|
||||
loss = loss_fn(logits, mb_input)
|
||||
|
||||
# Simple weight calculation (could be improved)
|
||||
loss_scale = loss_weight_fn(mb_input) / total_loss_weight
|
||||
total_loss += loss.item() * loss_scale
|
||||
total_weight += loss_scale
|
||||
|
||||
return torch.tensor(total_loss / total_weight)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
input_: TensorDict,
|
||||
output_seqlens: List[int] | None = None,
|
||||
post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
|
||||
aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
|
||||
) -> Any | None:
|
||||
"""Forward pass with optional post-processing."""
|
||||
input_ = input_.to(self.device)
|
||||
cu_seqlens = pack_tensor_dict(input_)["cu_seqlens"]
|
||||
mb_list = self._prepare_mb_list(input_)
|
||||
|
||||
if output_seqlens is None:
|
||||
output_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist()
|
||||
|
||||
results = []
|
||||
for pad_length, padded_mb_input, mb_input in zip(
|
||||
mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
|
||||
):
|
||||
outputs = self.model(**padded_mb_input)
|
||||
logits = outputs.logits.squeeze(0)
|
||||
logits = logits[:-pad_length] if pad_length > 0 else logits
|
||||
|
||||
if post_hook:
|
||||
result = post_hook(logits, mb_input)
|
||||
results.append(result)
|
||||
else:
|
||||
results.append(logits)
|
||||
|
||||
res = aggregate_fn(results)
|
||||
output_seqlens = [output_seqlens[i] for i in mb_list.forward_indices]
|
||||
unpacked = unpack_sequence(res, lens=output_seqlens, dim=0)
|
||||
reordered = reorder_list(unpacked, mb_list.backward_indices)
|
||||
return pad_and_stack_tensors_along_first_dim(reordered)
|
||||
|
|
|
@@ -1,467 +0,0 @@
|
|||
import gc
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import transformers
|
||||
from safetensors.torch import save_file
|
||||
from tensordict import TensorDict
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
get_constant_schedule_with_warmup,
|
||||
get_linear_schedule_with_warmup,
|
||||
)
|
||||
|
||||
from arealite.api.cli_args import TrainEngineConfig
|
||||
from arealite.api.engine_api import (
|
||||
FinetuneSpec,
|
||||
SaveLoadMeta,
|
||||
TrainEngine,
|
||||
WeightUpdateMeta,
|
||||
)
|
||||
from arealite.utils.data import (
|
||||
MicroBatchList,
|
||||
amend_position_ids,
|
||||
pack_tensor_dict,
|
||||
pad_and_stack_tensors_along_first_dim,
|
||||
pad_mb_list,
|
||||
reorder_list,
|
||||
split_packed_tensor_dict_into_mb_list,
|
||||
unpack_sequence,
|
||||
unsqueeze_mb_list,
|
||||
)
|
||||
from arealite.utils.fsdp import get_cosine_schedule_with_warmup
|
||||
from arealite.utils.save_load import (
|
||||
get_state_dict_from_repo_id_or_path,
|
||||
is_existing_local_path,
|
||||
)
|
||||
from realhf.api.core.data_api import load_hf_tokenizer
|
||||
from realhf.base import logging, name_resolve, names
|
||||
|
||||
logger = logging.getLogger("HFEngine")
|
||||
|
||||
|
||||
class HFEngine(TrainEngine):
|
||||
def __init__(self, config: TrainEngineConfig):
|
||||
self.config = config
|
||||
self.optimizer_config = config.optimizer
|
||||
|
||||
self.model = None
|
||||
self.optimizer = None
|
||||
self.tokenizer = None
|
||||
# huggingface model config
|
||||
self.model_config = None
|
||||
# initialization
|
||||
self.initialized = False
|
||||
self.weight_update_group_initialized = False
|
||||
|
||||
def train(self, mode: bool = True):
|
||||
assert self.model is not None
|
||||
self.model.train(mode=mode)
|
||||
return self
|
||||
|
||||
def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None):
|
||||
"""Initialize distributed communication and model."""
|
||||
assert addr is None, "HFEngine does not support remote initialization."
|
||||
|
||||
world_size = int(os.environ.get("WORLD_SIZE", 0))
|
||||
if not dist.is_initialized() and world_size > 1:
|
||||
try:
|
||||
import deepspeed
|
||||
except ImportError:
|
||||
print(
|
||||
"Warning: deepspeed is not installed. Some functionality may be disabled."
|
||||
)
|
||||
deepspeed.init_distributed(dist_backend="nccl", world_size=world_size)
|
||||
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
torch.cuda.set_device(local_rank)
|
||||
self.device = torch.device(f"cuda:{local_rank}")
|
||||
|
||||
dtype = getattr(torch, self.config.dtype)
|
||||
self.model_config = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path=self.config.path,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
self.tokenizer = load_hf_tokenizer(self.config.path)
|
||||
|
||||
self.model = AutoModelForCausalLM.from_config(
|
||||
self.model_config,
|
||||
torch_dtype=dtype,
|
||||
attn_implementation=self.config.attn_impl,
|
||||
).to(f"cuda:{local_rank}")
|
||||
|
||||
if not self.config.init_from_scratch:
|
||||
# Load model from a initial checkpoint path,
|
||||
# which should only be a huggingface checkpoint.
|
||||
load_meta = SaveLoadMeta(
|
||||
path=self.config.path,
|
||||
weight_format="hf",
|
||||
with_optim=False,
|
||||
tokenizer=None,
|
||||
base_model_path=self.config.path,
|
||||
naive_distributed=False,
|
||||
)
|
||||
|
||||
self.load(load_meta)
|
||||
|
||||
if world_size > 1:
|
||||
if self._check_autotp():
|
||||
self.model = deepspeed.tp_model_init(
|
||||
self.model, tp_size=self.config.hf.autotp_size, dtype=dtype
|
||||
)
|
||||
else:
|
||||
raise RuntimeError("DeepSpeed AutoTP configuration error in HFEngine. ")
|
||||
|
||||
# Set up optimizer
|
||||
if self.optimizer_config is not None:
|
||||
assert (
|
||||
self.optimizer_config.type == "adam"
|
||||
), "Only AdamW optimizer is supported in this engine."
|
||||
lr = self.optimizer_config.lr
|
||||
weight_decay = self.optimizer_config.weight_decay
|
||||
beta1 = self.optimizer_config.beta1
|
||||
beta2 = self.optimizer_config.beta2
|
||||
eps = self.optimizer_config.eps
|
||||
|
||||
self.optimizer = torch.optim.AdamW(
|
||||
self.model.parameters(),
|
||||
lr=lr,
|
||||
weight_decay=weight_decay,
|
||||
betas=(beta1, beta2),
|
||||
eps=eps,
|
||||
)
|
||||
total_train_steps = ft_spec.total_train_steps
|
||||
num_warmup_steps = int(
|
||||
self.optimizer_config.warmup_steps_proportion * total_train_steps
|
||||
)
|
||||
|
||||
if self.optimizer_config.lr_scheduler_type == "cosine":
|
||||
self.lr_scheduler = get_cosine_schedule_with_warmup(
|
||||
self.optimizer,
|
||||
num_warmup_steps,
|
||||
total_train_steps,
|
||||
min_lr_ratio=self.optimizer_config.min_lr_ratio,
|
||||
)
|
||||
elif self.optimizer_config.lr_scheduler_type == "linear":
|
||||
self.lr_scheduler = get_linear_schedule_with_warmup(
|
||||
self.optimizer,
|
||||
num_warmup_steps,
|
||||
total_train_steps,
|
||||
)
|
||||
elif self.optimizer_config.lr_scheduler_type == "constant":
|
||||
self.lr_scheduler = get_constant_schedule_with_warmup(
|
||||
self.optimizer,
|
||||
num_warmup_steps,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown lr scheduler type {self.optimizer_config.lr_scheduler_type}"
|
||||
)
|
||||
|
||||
self.initialized = True
|
||||
|
||||
def _check_autotp(self):
|
||||
tp_size = self.config.hf.autotp_size
|
||||
config = self.model_config
|
||||
num_attention_heads = config.num_attention_heads
|
||||
num_key_value_heads = config.num_key_value_heads
|
||||
hidden_size = config.hidden_size
|
||||
intermediate_size = config.intermediate_size
|
||||
|
||||
return (
|
||||
num_attention_heads % tp_size == 0
|
||||
and num_key_value_heads % tp_size == 0
|
||||
and hidden_size % tp_size == 0
|
||||
and intermediate_size % tp_size == 0
|
||||
)
|
||||
|
||||
def destroy(self):
|
||||
"""Destroy the engine and release GPU memory."""
|
||||
self.model = None
|
||||
self.optimizer = None
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
self.initialized = False
|
||||
|
||||
def save(self, meta: SaveLoadMeta):
|
||||
if meta.weight_format == "hf":
|
||||
self._save_model_to_hf(meta.path, meta.tokenizer, meta.naive_distributed)
|
||||
elif meta.weight_format == "dcp":
|
||||
# TODO: implement DCP save/load for HF
|
||||
raise NotImplementedError("DCP format saving is not implemented yet. ")
|
||||
else:
|
||||
raise ValueError(f"Unknown weight format {meta.weight_format}. ")
|
||||
|
||||
if meta.with_optim:
|
||||
self._save_optimizer_state(meta.path)
|
||||
|
||||
def load(self, meta: SaveLoadMeta):
|
||||
if meta.weight_format == "hf":
|
||||
self._load_model_from_hf(meta.path, meta.naive_distributed)
|
||||
elif meta.weight_format == "dcp":
|
||||
# TODO: implement DCP save/load for HF
|
||||
raise NotImplementedError("DCP format loading is not implemented yet. ")
|
||||
else:
|
||||
raise ValueError(f"Unknown weight format {meta.weight_format}. ")
|
||||
|
||||
if meta.with_optim:
|
||||
self._load_optimizer_state(meta.path)
|
||||
|
||||
def _save_optimizer_state(self, path: str):
|
||||
assert self.optimizer is not None
|
||||
os.makedirs(path, exist_ok=True)
|
||||
torch.save(self.optimizer.state_dict(), os.path.join(path, "optim.pt"))
|
||||
|
||||
def _load_optimizer_state(self, path: str):
|
||||
assert self.optimizer is not None
|
||||
path = os.path.join(path, "optim.pt")
|
||||
optimizer_state_dict = torch.load(path, weights_only=False)
|
||||
self.optimizer.load_state_dict(optimizer_state_dict)
|
||||
|
||||
def _save_model_to_hf(
|
||||
self,
|
||||
path: str,
|
||||
tokenizer: Optional[transformers.PreTrainedTokenizerFast],
|
||||
naive_distributed: bool,
|
||||
):
|
||||
"""Save model in HuggingFace format."""
|
||||
if self.model is None:
|
||||
raise RuntimeError("Model not initialized")
|
||||
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
if rank == 0:
|
||||
os.makedirs(path, exist_ok=True)
|
||||
self.model_config.save_pretrained(path)
|
||||
if tokenizer is not None:
|
||||
tokenizer.save_pretrained(path)
|
||||
|
||||
if world_size > 1:
|
||||
dist.barrier()
|
||||
|
||||
state_dict = self.model.state_dict()
|
||||
|
||||
if hasattr(self.model, "module"):
|
||||
state_dict = {
|
||||
k.replace("module.", "", 1) if k.startswith("module.") else k: v.cpu()
|
||||
for k, v in state_dict.items()
|
||||
}
|
||||
else:
|
||||
state_dict = {k: v.cpu() for k, v in state_dict.items()}
|
||||
|
||||
if world_size > 1 and naive_distributed:
|
||||
# Only support store parameters from model partitions respectively
|
||||
gathered_state_dicts = None
|
||||
if rank == 0:
|
||||
gathered_state_dicts = [None for _ in range(world_size)]
|
||||
|
||||
dist.gather_object(
|
||||
obj=state_dict, object_gather_list=gathered_state_dicts, dst=0
|
||||
)
|
||||
|
||||
if rank == 0:
|
||||
for i, state_dict in enumerate(gathered_state_dicts):
|
||||
save_file(state_dict, f"{path}/rank_{i:02d}_model.safetensors")
|
||||
else:
|
||||
self.model.save_pretrained(path, state_dict=state_dict)
|
||||
|
||||
if world_size > 1:
|
||||
dist.barrier()
|
||||
|
||||
def _load_model_from_hf(self, path: str, naive_distributed: bool):
|
||||
"""Load model from HuggingFace format."""
|
||||
|
||||
rank = dist.get_rank()
|
||||
# Only support load full model parameters from huggingface
|
||||
# and load model partition locally
|
||||
if rank == 0 or is_existing_local_path(path):
|
||||
if naive_distributed:
|
||||
path = f"{path}/rank_{rank:02d}_model.safetensors"
|
||||
full_state = get_state_dict_from_repo_id_or_path(path)
|
||||
|
||||
if hasattr(self.model, "module") and not hasattr(full_state):
|
||||
full_state = {
|
||||
f"module.{k}" if not k.startswith("module.") else k: v
|
||||
for k, v in full_state.items()
|
||||
}
|
||||
self.model.load_state_dict(
|
||||
full_state, strict=not self.model_config.tie_word_embeddings
|
||||
)
|
||||
|
||||
if self.model_config.tie_word_embeddings:
|
||||
self.model.tie_weights()
|
||||
|
||||
def upload_weights(self, meta: WeightUpdateMeta):
|
||||
if meta.type == "nccl":
|
||||
if not self.weight_update_group_initialized:
|
||||
self._init_distributed_weight_update(meta)
|
||||
self._update_weights_from_distributed()
|
||||
elif meta.type == "disk":
|
||||
self._save_model_to_hf(meta.path, self.tokenizer, meta.naive_distributed)
|
||||
update_name = names.update_weights_from_disk(
|
||||
self.config.experiment_name,
|
||||
self.config.trial_name,
|
||||
meta.model_version,
|
||||
)
|
||||
name_resolve.add(update_name, str(time.time_ns()), keepalive_ttl=120)
|
||||
else:
|
||||
raise ValueError(f"Unknown weight update type {meta.type}")
|
||||
|
||||
def _init_distributed_weight_update(self, meta: WeightUpdateMeta):
|
||||
raise NotImplementedError(
|
||||
"Distributed weight update is not implemented for HFEngine yet. "
|
||||
)
|
||||
|
||||
def _update_weights_from_distributed(self):
|
||||
raise NotImplementedError(
|
||||
"Distributed weight update is not implemented for HFEngine yet. "
|
||||
)
|
||||
|
||||
def step_lr_scheduler(self):
|
||||
assert self.lr_scheduler is not None
|
||||
return self.lr_scheduler.step()
|
||||
|
||||
def _prepare_mb_list(self, input_: TensorDict) -> MicroBatchList:
|
||||
assert "attention_mask" in input_ and "input_ids" in input_
|
||||
if isinstance(input_, dict):
|
||||
input_ = TensorDict(input_, batch_size=[input_["input_ids"].shape[0]])
|
||||
input_ = amend_position_ids(input_)
|
||||
packed_input = pack_tensor_dict(input_)
|
||||
mb_list = split_packed_tensor_dict_into_mb_list(
|
||||
packed_input,
|
||||
self.config.mb_spec,
|
||||
)
|
||||
mb_list = pad_mb_list(mb_list, pad_value=0.0)
|
||||
# NOTE: We unsqueeze here because huggingface transformer models requires
|
||||
# packed input to be of shape [1, total_seqlen].
|
||||
mb_list = unsqueeze_mb_list(mb_list)
|
||||
return mb_list

    def train_batch(
        self,
        input_: TensorDict,
        loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
        loss_weight_fn: Callable[[Dict], float],
    ) -> Dict[str, float]:
        """Train on a batch using gradient accumulation."""
        input_ = input_.to(self.device)
        assert self.optimizer is not None
        assert self.optimizer_config is not None
        assert self.lr_scheduler is not None

        self.optimizer.zero_grad()
        mb_list = self._prepare_mb_list(input_)

        total_loss_weight = torch.tensor(
            sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32
        )
        assert total_loss_weight != 0

        # Process microbatches with gradient accumulation
        for pad_length, padded_mb_input, mb_input in zip(
            mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
        ):
            outputs = self.model(**padded_mb_input)

            logits = outputs.logits.squeeze(0)
            logits = logits[:-pad_length] if pad_length > 0 else logits
            loss = loss_fn(logits, mb_input)
            loss_scale = loss_weight_fn(mb_input) / total_loss_weight

            loss *= loss_scale
            loss.backward()

        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            self.optimizer_config.gradient_clipping,
            norm_type=2.0,
            error_if_nonfinite=False,
            foreach=None,
        )
        if not torch.isfinite(grad_norm):
            self.optimizer.zero_grad()
            update_successful = False
        else:
            self.optimizer.step()
            update_successful = True

        current_lr = self.lr_scheduler.get_last_lr()[0]
        # Optimizer step
        self.optimizer.step()
        return dict(
            update_successful=float(update_successful),
            grad_norm=float(grad_norm) if grad_norm is not None else float("nan"),
            lr=current_lr,
        )
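
The `loss_weight_fn(mb) / total_loss_weight` rescaling is what keeps gradient accumulation over unevenly sized micro-batches equivalent to a single full-batch step: with token-count weights, the sum of rescaled per-micro-batch means equals the mean over all tokens. A small self-contained check of that identity:

import torch

# Token-weighted accumulation over uneven micro-batches (illustration only).
mb_token_losses = [torch.tensor([0.5, 0.7]), torch.tensor([0.2, 0.4, 0.6])]
token_counts = [t.numel() for t in mb_token_losses]
total = sum(token_counts)

accumulated = sum((n / total) * t.mean() for n, t in zip(token_counts, mb_token_losses))
full_batch = torch.cat(mb_token_losses).mean()
assert torch.isclose(accumulated, full_batch)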

    @torch.no_grad()
    def eval_batch(
        self,
        input_: TensorDict,
        loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
        loss_weight_fn: Callable[[Dict], float],
    ) -> torch.Tensor | None:
        """Evaluate on a batch."""
        mb_list = self._prepare_mb_list(input_)
        total_loss_weight = torch.tensor(
            sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32
        )
        assert total_loss_weight != 0

        total_loss = 0.0
        total_weight = 0.0

        for pad_length, padded_mb_input, mb_input in zip(
            mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
        ):
            outputs = self.model(**padded_mb_input)
            logits = outputs.logits.squeeze(0)
            logits = logits[:-pad_length] if pad_length > 0 else logits
            loss = loss_fn(logits, mb_input)

            # Simple weight calculation (could be improved)
            loss_scale = loss_weight_fn(mb_input) / total_loss_weight
            total_loss += loss.item() * loss_scale
            total_weight += loss_scale

        return torch.tensor(total_loss / total_weight)

    @torch.no_grad()
    def forward(
        self,
        input_: TensorDict,
        output_seqlens: List[int] | None = None,
        post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
        aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
    ) -> Any | None:
        """Forward pass with optional post-processing."""
        cu_seqlens = pack_tensor_dict(input_)["cu_seqlens"]
        mb_list = self._prepare_mb_list(input_)

        if output_seqlens is None:
            output_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist()

        results = []
        for pad_length, padded_mb_input, mb_input in zip(
            mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
        ):
            outputs = self.model(**padded_mb_input)
            logits = outputs.logits.squeeze(0)
            logits = logits[:-pad_length] if pad_length > 0 else logits

            if post_hook:
                result = post_hook(logits, mb_input)
                results.append(result)
            else:
                results.append(logits)

        res = aggregate_fn(results)
        output_seqlens = [output_seqlens[i] for i in mb_list.forward_indices]
        unpacked = unpack_sequence(res, lens=output_seqlens, dim=0)
        reordered = reorder_list(unpacked, mb_list.backward_indices)
        return pad_and_stack_tensors_along_first_dim(reordered)
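
`forward` runs micro-batches in the order produced by the splitter (`forward_indices`) and then restores the caller's original sequence order through `backward_indices`. The sketch below shows the assumed inverse-permutation contract on plain lists; the real helpers operate on packed tensors.

# Sequences are processed in a permuted order (forward_indices) and results are
# restored to the caller's original order afterwards (backward_indices is the
# inverse permutation).
def invert_permutation(forward_indices):
    backward_indices = [0] * len(forward_indices)
    for new_pos, old_pos in enumerate(forward_indices):
        backward_indices[old_pos] = new_pos
    return backward_indices


original = ["seq0", "seq1", "seq2", "seq3"]
forward_indices = [2, 0, 3, 1]            # order used to build micro-batches
processed = [original[i] for i in forward_indices]
backward_indices = invert_permutation(forward_indices)
restored = [processed[i] for i in backward_indices]
assert restored == original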

@@ -114,7 +114,9 @@ class PPOActor:
            values = torch.zeros_like(rewards)
        else:
            values = data["values"]
        advantages_reversed = []
        advantages_reversed = [
            torch.zeros(bs, dtype=torch.float32, device=values.device)
        ]
        lastgaelam = 0
        for t in reversed(range(max_seqlen - 1)):
            nextvalues = values[:, t + 1]
@@ -123,9 +125,6 @@ class PPOActor:
            delta = rewards[:, t] + self.discount * nextvalues - values[:, t]
            lastgaelam = delta + self.discount * self.gae_lambda * lastgaelam
            advantages_reversed.append(lastgaelam)
        advantages_reversed.append(
            torch.zeros(bs, dtype=torch.float32, device=values.device)
        )
        advantages = torch.stack(advantages_reversed[::-1], dim=1)

        # Optionally perform advantage normalization.
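
The two hunks above change where the terminal zero advantage is placed: it is now seeded as the first element of `advantages_reversed` instead of being appended after the loop, so the reversed stack still has `max_seqlen` entries. For reference, a self-contained GAE recursion in the same shape convention ([batch, T] rewards and values), kept generic rather than copying PPOActor's masking:

import torch


def gae_advantages(rewards, values, discount=1.0, gae_lambda=1.0):
    """Generic GAE over a [batch, T] reward/value grid; illustration only."""
    bs, T = rewards.shape
    adv_reversed = [torch.zeros(bs, dtype=torch.float32, device=values.device)]
    lastgaelam = torch.zeros(bs, dtype=torch.float32, device=values.device)
    for t in reversed(range(T - 1)):
        nextvalues = values[:, t + 1]
        delta = rewards[:, t] + discount * nextvalues - values[:, t]
        lastgaelam = delta + discount * gae_lambda * lastgaelam
        adv_reversed.append(lastgaelam)
    return torch.stack(adv_reversed[::-1], dim=1)  # shape [batch, T]


adv = gae_advantages(torch.zeros(2, 5), torch.zeros(2, 5))
print(adv.shape)  # torch.Size([2, 5])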

@@ -1,5 +1,3 @@
from typing import Dict

import torch
import torch.utils.data
from tensordict import TensorDict

@@ -44,9 +42,7 @@ class FSDPLMEngine(FSDPEngine):
        return self.lm_engine.evaluate_lm(data)


def compute_packed_sft_loss(
    logits: torch.Tensor, input_: Dict[str, torch.Tensor]
) -> torch.Tensor:
def compute_packed_sft_loss(logits: torch.Tensor, input_: TensorDict) -> torch.Tensor:
    packed_input_ids: torch.Tensor = input_["input_ids"]
    cu_seqlens: torch.Tensor = input_["cu_seqlens"]
    loss_mask = input_["loss_mask"].bool()

@@ -7,10 +7,12 @@ import traceback
from concurrent.futures import ProcessPoolExecutor
from datetime import datetime
from queue import Empty, Full, Queue
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List
from typing import TYPE_CHECKING, Any, Callable, Dict, List

import aiohttp
import requests
import torch.distributed as dist
import uvloop
from tensordict import TensorDict
from torchdata.stateful_dataloader import StatefulDataLoader

@@ -23,8 +25,8 @@ from arealite.api.io_struct import (
    RolloutStat,
    WeightUpdateMeta,
)
from arealite.utils.http import arequest_with_retry
from arealite.utils.padding import concat_padded_tensors
from arealite.utils.data import concat_padded_tensors
from arealite.utils.http import arequest_with_retry, get_default_connector
from realhf.base import logging, name_resolve, names, pkg_version

if TYPE_CHECKING:
@@ -37,7 +39,7 @@ if pkg_version.is_available("sglang"):
else:
    SGLANG_TOKEN_OUTPUT_IDENTIFIER = "token_ids"

ROLLOUT_POLL_WAIT_TIME = 0.1
ROLLOUT_POLL_WAIT_TIME = 0.05
RID_CACHE_SIZE = 128


@@ -98,6 +100,7 @@ class RemoteSGLangEngine(InferenceEngine):

    def initialize(self, addr: str | None, ft_spec: FinetuneSpec = None):
        self.rollout_tasks: Dict[str, asyncio.Task] = {}

        self.executor = ProcessPoolExecutor(max_workers=1)
        self.rollout_thread = threading.Thread(target=self._rollout_thread)
        self.rollout_thread.start()
@@ -118,32 +121,39 @@ class RemoteSGLangEngine(InferenceEngine):
    def _rollout_thread(self):
        """Thread that runs the rollout loop."""
        try:
            asyncio.run(self._rollout_thread_async())
            uvloop.run(self._rollout_thread_async())
        except Exception as e:
            traceback.print_exc()

    async def _rollout_thread_async(self):
        pending_data = []
        rollout_tasks = self.rollout_tasks
        rid = 0

        # NOTE: session is not thread-safe, but we only submit requests in the sub-thread.
        self.session = aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(
                total=self.config.request_timeout,
                sock_connect=self.config.request_timeout,
                connect=self.config.request_timeout,
            ),
            read_bufsize=1024 * 1024 * 10,
            connector=get_default_connector(),
        )

        try:
            while not self.exiting.is_set():
                # Load next data from controller
                while True:
                    try:
                        data, workflow = self.input_queue.get_nowait()
                        logger.debug(f"Get data from puller: {data}")
                        pending_data.append(data)
                    except Empty:
                        logger.debug(f"No data from puller stream.")
                        break

                # Check capacity
                capacity = self.get_capacity()
                # Create new rollout task
                while capacity > 0 and pending_data and not self.paused.is_set():
                while (
                    capacity > 0
                    and not self.paused.is_set()
                    and self.input_queue.qsize() > 0
                ):
                    data, workflow = self.input_queue.get_nowait()
                    logger.debug(f"Get data from puller: {data}")
                    task = asyncio.create_task(
                        workflow.arun_episode(self, pending_data.pop(0)), name=str(rid)
                        workflow.arun_episode(self, data), name=str(rid)
                    )
                    with self.lock:
                        rollout_tasks[str(rid)] = task
@@ -158,7 +168,6 @@ class RemoteSGLangEngine(InferenceEngine):
                    )
                    capacity -= 1
                    rid += 1

                # Wait for rollout completion
                with self.lock:
                    tasks = list(rollout_tasks.values())
@@ -169,11 +178,6 @@ class RemoteSGLangEngine(InferenceEngine):
                        timeout=ROLLOUT_POLL_WAIT_TIME,
                        return_when=asyncio.FIRST_COMPLETED,
                    )
                    if not done:
                        await asyncio.sleep(1)
                else:
                    await asyncio.sleep(1)

                # Collect done results
                for task in done:
                    traj = await task
@@ -199,6 +203,7 @@ class RemoteSGLangEngine(InferenceEngine):
                        f"running: {self.rollout_stat.running}, "
                        f"accepted: {self.rollout_stat.accepted}."
                    )
                await asyncio.sleep(1)
        except Exception:
            traceback.print_exc()
        finally:
@@ -213,10 +218,11 @@ class RemoteSGLangEngine(InferenceEngine):
            pass

    def choose_server(self) -> str:
        if self.config.schedule_policy == "round_robin":
            server = self.addresses[self.server_idx]
            self.server_idx = (self.server_idx + 1) % len(self.addresses)
            return server
        with self.lock:
            if self.config.schedule_policy == "round_robin":
                server = self.addresses[self.server_idx]
                self.server_idx = (self.server_idx + 1) % len(self.addresses)
                return server
        raise NotImplementedError("Only round-robin scheduling is implemented.")

    async def agenerate(self, req: LLMRequest) -> LLMResponse:
@@ -253,7 +259,6 @@ class RemoteSGLangEngine(InferenceEngine):
        accumulated_versions = []

        # Deal with rollout interruption
        completions = ""
        stop_reason = "length"

        if req.rid in self.rid_to_address:
@@ -273,11 +278,12 @@ class RemoteSGLangEngine(InferenceEngine):
        ):
            # loop until the generation is complete
            result = await arequest_with_retry(
                addr=self.choose_server(),
                session=self.session,
                addr=server_addr,
                endpoint="/generate",
                payload=payload,
                method="POST",
                max_retries=3,
                max_retries=self.config.request_retries,
                timeout=self.config.request_timeout,
            )

@@ -297,10 +303,7 @@ class RemoteSGLangEngine(InferenceEngine):
            stop_reason = finish_reason["type"]

            payload["input_ids"] += result[SGLANG_TOKEN_OUTPUT_IDENTIFIER]
            sample_params["max_new_tokens"] = min(
                sample_params["max_new_tokens"],
                gconfig.max_new_tokens - len(output_tokens),
            )
            sample_params["max_new_tokens"] -= len(output_tokens)

        latency = time.perf_counter() - start_time

@@ -408,18 +411,24 @@ class RemoteSGLangEngine(InferenceEngine):

    def prepare_batch(
        self,
        data_generator: Iterator,
        dataloader: StatefulDataLoader,
        workflow: "RolloutWorkflow",
    ):
        if not hasattr(self, "data_generator"):
            self.data_generator = iter(dataloader)
        assert dataloader.batch_size is not None
        while True:
            if self.get_capacity() + dataloader.batch_size > 0:
            # Submit at least two batches to allow maximum overlap
            if (
                self.get_capacity() + dataloader.batch_size > 0
                and self.input_queue.qsize() + dataloader.batch_size
                < self.input_queue.maxsize
            ):
                try:
                    data = next(data_generator)
                    data = next(self.data_generator)
                except StopIteration:
                    data_generator = iter(dataloader)
                    data = next(data_generator)
                    self.data_generator = iter(dataloader)
                    data = next(self.data_generator)
                for item in data:
                    self.submit(item, workflow=workflow)
                try:
@@ -435,10 +444,12 @@ class RemoteSGLangEngine(InferenceEngine):


async def aupdate_weights_from_disk(
    addr, path: str, request_retries: int, request_timeout: float
    session, addr, path: str, request_retries: int, request_timeout: float
):
    tik = time.time()
    res = await arequest_with_retry(
        addr=addr,
        session=session,
        endpoint="/update_weights_from_disk",
        payload=dict(model_path=str(path), allow_interrupt=True),
        method="POST",
@@ -472,9 +483,19 @@ def update_weights_from_disk(
        logger.info(
            f"Begin update weights from {path}, responded in {(load_timestamp - save_timestamp):.2f}s"
        )
        session = aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(
                total=request_timeout,
                sock_connect=request_timeout,
                connect=request_timeout,
            ),
            read_bufsize=1024 * 1024 * 10,
            connector=get_default_connector(),
        )
        jobs = [
            aupdate_weights_from_disk(
                addr,
                session=session,
                addr=addr,
                path=path,
                request_retries=request_retries,
                request_timeout=request_timeout,
@@ -482,8 +503,9 @@ def update_weights_from_disk(
            for addr in addresses
        ]
        await asyncio.gather(*jobs)
        await session.close()
        logger.info(
            f"Loading weights done in {(datetime.now().timestamp() - load_timestamp):.2f}s"
        )

    return asyncio.run(_fn())
    return uvloop.run(_fn())
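
`update_weights_from_disk` now builds one shared `aiohttp.ClientSession` (with the default connector) and fans the `/update_weights_from_disk` request out to every server before closing the session, driven by `uvloop.run`. A stripped-down version of that fan-out pattern, with a placeholder coroutine standing in for the real retrying request:

import asyncio

import aiohttp
import uvloop


async def _post(session: aiohttp.ClientSession, addr: str, payload: dict) -> int:
    # Placeholder for arequest_with_retry(...); returns only the HTTP status code.
    async with session.post(
        f"http://{addr}/update_weights_from_disk", json=payload
    ) as resp:
        return resp.status


def broadcast_update(addresses: list, model_path: str):
    async def _fn():
        session = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=3600))
        try:
            jobs = [_post(session, addr, {"model_path": model_path}) for addr in addresses]
            return await asyncio.gather(*jobs)
        finally:
            await session.close()

    return uvloop.run(_fn())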

@@ -5,6 +5,7 @@ import time
import uuid

import pytest
import requests
import torch
from tensordict import TensorDict

@@ -28,6 +29,17 @@ HOST = network.gethostip()
RUN_SERVER_TIMEOUT = 180


def check_server_health(base_url):
    try:
        response = requests.get(
            f"{base_url}/metrics",
            timeout=30,
        )
        return response.status_code == 200
    except requests.exceptions.RequestException as e:
        return False


@pytest.fixture(scope="module")
def sglang_server():
    from realhf.base import seeding
@@ -67,7 +67,6 @@ async def test_local_sglang_generate():
        gconfig=GenerationHyperparameters(max_new_tokens=16),
    )
    resp = await engine.agenerate(req)
    print(resp.completions)

    assert isinstance(resp, LLMResponse)
    assert resp.input_tokens == req.input_ids
@@ -76,9 +75,6 @@ async def test_local_sglang_generate():
        == len(resp.output_tokens)
        == len(resp.output_versions)
    )
    assert isinstance(resp.completions, str)

    time.sleep(5)
    engine.destroy()


@@ -13,7 +13,6 @@ from transformers import AutoTokenizer

from arealite.api.cli_args import MicroBatchSpec, OptimizerConfig, TrainEngineConfig
from arealite.api.io_struct import FinetuneSpec, SaveLoadMeta
from arealite.engine.fsdp_engine import FSDPEngine

VOCAB_SIZE = 100
MODEL_PATH = "/storage/testing/models/Qwen__Qwen3-1.7B/"
@@ -53,10 +52,10 @@ def mock_input(


def get_engine(engine_type: str, model_path: str):
    from arealite.engine.autotp_engine import DeepSpeedAutoTPEngine
    from arealite.engine.fsdp_engine import FSDPEngine
    from arealite.engine.hf_engine import HFEngine

    engine_cls = {"hf": HFEngine, "fsdp": FSDPEngine}[engine_type]
    engine_cls = {"auto_tp": DeepSpeedAutoTPEngine, "fsdp": FSDPEngine}[engine_type]

    engine_config = TrainEngineConfig(
        experiment_name=f"test-{engine_type}-engine",
@@ -75,7 +74,7 @@ def mock_loss_fn(logits: torch.Tensor, input_data: Dict) -> torch.Tensor:
    return torch.mean(logits)


@pytest.fixture(scope="module", params=["fsdp", "hf"])
@pytest.fixture(scope="module", params=["fsdp", "auto_tp"])
def engine(request):
    os.environ.update(
        {
@@ -136,6 +135,12 @@ def test_train_batch(engine, mock_input):

@torch.no_grad()
def test_hf_save_load_weights(tmp_path_factory, engine, mock_input):
    from arealite.engine.autotp_engine import DeepSpeedAutoTPEngine

    if isinstance(engine, DeepSpeedAutoTPEngine):
        print("AutoTP engine does not support HF save/load for now.")
        return

    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    path = tmp_path_factory.mktemp("hf_engine_test")
    save_load_meta = SaveLoadMeta(

@@ -92,9 +92,7 @@ def unpad_input(
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(
        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
    )
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return (
        rearrange(hidden_states, "b s ... -> (b s) ...")[indices],
        indices,
@@ -116,16 +114,12 @@ def concat_padded_tensors(
    if not tensor_dicts:
        return TensorDict()

    # Find max sequence length across all dictionaries
    lens = []
    for tensor_dict in tensor_dicts:
        for key, tensor in tensor_dict.items():
            if key != "attention_mask" and len(tensor.shape) == 2:
                lens.append(tensor.shape[1])
                break
    max_length = max(lens)
    attn_mask = torch.arange(max_length).unsqueeze(0) < torch.tensor(lens).unsqueeze(1)
    batch_sizes = [tuple(d.batch_size) for d in tensor_dicts]
    new_batch_size = [sum(x[0] for x in batch_sizes), *batch_sizes[0][1:]]

    # Find max sequence length across all dictionaries
    assert all("attention_mask" in td for td in tensor_dicts)
    max_length = max([x["attention_mask"].shape[1] for x in tensor_dicts])
    result = {}
    # Process each key
    for key in tensor_dicts[0].keys():
@@ -154,9 +148,7 @@ def concat_padded_tensors(
            tensors_to_concat.append(tensor)

        result[key] = torch.cat(tensors_to_concat, dim=0)
    if "attention_mask" not in result:
        result["attention_mask"] = attn_mask
    return TensorDict(result, batch_size=[len(lens)])
    return TensorDict(result, batch_size=new_batch_size)
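
The rewritten `concat_padded_tensors` takes the target length from the attention masks themselves and derives the output batch size from the inputs' batch sizes, rather than inferring both from the first 2-D tensor it happens to find. Expected behaviour, sketched with two toy TensorDicts (only shapes are asserted; the exact pad values depend on the implementation):

import torch
from tensordict import TensorDict

from arealite.utils.data import concat_padded_tensors

a = TensorDict(
    {"input_ids": torch.ones(2, 3, dtype=torch.long),
     "attention_mask": torch.ones(2, 3, dtype=torch.bool)},
    batch_size=[2],
)
b = TensorDict(
    {"input_ids": torch.ones(1, 5, dtype=torch.long),
     "attention_mask": torch.ones(1, 5, dtype=torch.bool)},
    batch_size=[1],
)

out = concat_padded_tensors([a, b])
# The shorter dict is padded to the longer one's length; batch sizes add up.
assert out["input_ids"].shape == (3, 5)
assert tuple(out.batch_size) == (3,)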


def to_device(data: Dict[str, torch.Tensor | Any], device) -> Dict[str, torch.Tensor]:
@@ -231,7 +223,7 @@ def pack_tensor_dict(data: TensorDict):
    cu_seqlens = torch.cumsum(lens, dim=0)
    cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

    total_length = cu_seqlens[-1].item()
    total_length = int(cu_seqlens[-1].item())
    # Pack tensors
    packed_data = {}
    for key, value in data.items():
@@ -246,7 +238,7 @@ def pack_tensor_dict(data: TensorDict):
            and value.shape[1] == seq_len
        ):
            packed_tensor = torch.empty(
                total_length, *value.shape[2:], dtype=value.dtype, device=value.device
                (total_length, *value.shape[2:]), dtype=value.dtype, device=value.device
            )
            # Fill the packed tensor with values from the original tensor
            for i in range(bs):
@@ -363,7 +355,7 @@ def split_padded_tensor_dict_into_mb_list(
        mbs=results,
        mb_spec=mb_spec,
        forward_indices=forward_indices,
        backward_indices=backward_indices,
        backward_indices=backward_indices.tolist(),
        group_lens=group_lens,
    )

@@ -121,19 +121,24 @@ def masked_normalization(

def ppo_actor_loss_fn(
    logprobs: torch.Tensor,
    proximal_logprobs: torch.Tensor,
    old_logprobs: torch.Tensor,
    advantages: torch.Tensor,
    eps_clip: float,
    loss_mask: torch.Tensor,
    c_clip: Optional[float] = None,
    proximal_logprobs: Optional[torch.Tensor] = None,
    behav_imp_weight_cap: Optional[float] = None,
) -> Tuple[torch.Tensor, Dict]:
    denorm_logprobs = (
        proximal_logprobs if proximal_logprobs is not None else old_logprobs
    )
    """
    When decoupled loss is disabled:
    1. if recompute logp, both old_logprobs and proximal_logprobs are recomputed logp;
    2. if no recomputation, both old_logp and proximal_logprobs are produced by the inference backend.

    When decoupled loss is enabled, proximal_logprobs is the recomputed logp,
    old_logprobs is produced by the inference engine.
    """
    loss_mask_count = loss_mask.count_nonzero() or 1
    ratio = torch.where(loss_mask, torch.exp(logprobs - denorm_logprobs), 0)
    ratio = torch.where(loss_mask, torch.exp(logprobs - proximal_logprobs), 0)
    clipped_ratio = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip)
    pg_loss1 = -advantages * ratio
    pg_loss2 = -advantages * clipped_ratio
@@ -146,17 +151,16 @@ def ppo_actor_loss_fn(
        pg_loss = torch.min(pg_loss, pg_loss3)
    else:
        dual_clip_mask = torch.zeros_like(clip_mask)
    if proximal_logprobs is not None:
        behav_kl = proximal_logprobs - old_logprobs
        behav_imp_weight = behav_kl.exp()
        behav_mask = (
            (behav_imp_weight <= behav_imp_weight_cap).logical_and(loss_mask)
            if behav_imp_weight_cap is not None
            else loss_mask
        )
        behav_kl = torch.where(behav_mask, behav_kl, 0.0)
        behav_imp_weight = torch.where(behav_mask, behav_imp_weight, 0.0)
        pg_loss = pg_loss * behav_imp_weight
    behav_kl = proximal_logprobs - old_logprobs
    behav_imp_weight = behav_kl.exp()
    behav_mask = (
        (behav_imp_weight <= behav_imp_weight_cap).logical_and(loss_mask)
        if behav_imp_weight_cap is not None
        else loss_mask
    )
    behav_kl = torch.where(behav_mask, behav_kl, 0.0)
    behav_imp_weight = torch.where(behav_mask, behav_imp_weight, 0.0)
    pg_loss = pg_loss * behav_imp_weight
    logging_loss = pg_loss.detach()
    pg_loss = torch.where(loss_mask, pg_loss, 0).sum() / loss_mask_count
    clip_mask.logical_and_(loss_mask)
@@ -164,7 +168,7 @@ def ppo_actor_loss_fn(
    stat = dict(
        loss=logging_loss,
        importance_weight=ratio.detach(),
        approx_kl=(logprobs - denorm_logprobs).detach(),
        approx_kl=(logprobs - proximal_logprobs).detach(),
        clip_mask=clip_mask,
        dual_clip_mask=dual_clip_mask,
    )
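
After this change, the clipped ratio is always taken against `proximal_logprobs` (the recomputed log-probs), while the extra importance weight `exp(proximal_logprobs - old_logprobs)` corrects for the behaviour policy that actually generated the tokens; when the two coincide that weight is 1 and the expression reduces to standard PPO. A compact functional sketch of that structure, without dual clipping or the cap handling:

import torch


def decoupled_ppo_loss(logprobs, proximal_logprobs, old_logprobs, advantages, eps_clip, loss_mask):
    # Clipped surrogate against the proximal (recomputed) policy.
    ratio = torch.where(loss_mask, torch.exp(logprobs - proximal_logprobs), 0.0)
    clipped = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip)
    pg_loss = torch.max(-advantages * ratio, -advantages * clipped)
    # Behaviour-policy correction; equals 1 when proximal == old.
    behav_imp_weight = torch.where(loss_mask, (proximal_logprobs - old_logprobs).exp(), 0.0)
    pg_loss = pg_loss * behav_imp_weight
    denom = loss_mask.count_nonzero().clamp(min=1)
    return torch.where(loss_mask, pg_loss, 0.0).sum() / denom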

@@ -3,45 +3,74 @@ from typing import Any, Dict, Optional

import aiohttp

from realhf.base import logging

DEFAULT_RETRIES = 1
DEFAULT_REQUEST_TIMEOUT = 3600


logger = logging.getLogger(__file__)


def get_default_connector():
    return aiohttp.TCPConnector(limit=0, use_dns_cache=False, force_close=True)


async def arequest_with_retry(
    addr: str,
    endpoint: str,
    payload: Optional[Dict[str, Any]] = None,
    session: aiohttp.ClientSession | None = None,
    method: str = "POST",
    max_retries: Optional[int] = None,
    timeout: Optional[float] = None,
    retry_delay: float = 1.0,
) -> aiohttp.ClientResponse:
    verbose=False,
) -> Dict:
    timeout = timeout or DEFAULT_REQUEST_TIMEOUT
    last_exception = None
    max_retries = max_retries or DEFAULT_RETRIES
    base_url = f"http://{addr}"
    url = f"{base_url}{endpoint}"

    timeo = aiohttp.ClientTimeout(
        total=timeout,
        sock_connect=timeout,
        connect=timeout,
    )
    if session is None:
        _session = aiohttp.ClientSession(
            timeout=timeo,
            read_bufsize=1024 * 1024 * 10,
            connector=get_default_connector(),
        )
    else:
        _session = session

    for attempt in range(max_retries):
        try:
            async with aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(
                    total=timeout,
                    sock_connect=timeout,
                )
            ) as session:
                if method.upper() == "GET":
                    response = await session.get(url)
                elif method.upper() == "POST":
                    response = await session.post(url, json=payload)
                elif method.upper() == "PUT":
                    response = await session.put(url, json=payload)
                elif method.upper() == "DELETE":
                    response = await session.delete(url)
                else:
                    raise ValueError(f"Unsupported HTTP method: {method}")
            if verbose:
                logger.info("enter client session, start sending requests")
            if method.upper() == "GET":
                ctx = _session.get(url, timeout=timeo)
            elif method.upper() == "POST":
                ctx = _session.post(url, json=payload, timeout=timeo)
            elif method.upper() == "PUT":
                ctx = _session.put(url, json=payload, timeout=timeo)
            elif method.upper() == "DELETE":
                ctx = _session.delete(url, timeout=timeo)
            else:
                raise ValueError(f"Unsupported HTTP method: {method}")
            async with ctx as response:
                if verbose:
                    logger.info("http requests return")
                response.raise_for_status()
                return await response.json()
                res = await response.json()
                if verbose:
                    logger.info("get http result")
            if session is None:
                await _session.close()
            return res
        except (
            aiohttp.ClientError,
            aiohttp.ClientResponseError,
@@ -51,6 +80,10 @@ async def arequest_with_retry(
            if attempt < max_retries - 1:
                await asyncio.sleep(retry_delay)
                continue
    if session is None:
        await _session.close()
    raise RuntimeError(
        f"Failed after {max_retries} retries each. " f"Last error: {last_exception}"
        f"Failed after {max_retries} retries each. "
        f"Payload: {payload}. Addr: {addr}. Endpoint: {endpoint}. "
        f"Last error: {last_exception}"
    )
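
`arequest_with_retry` now returns the decoded JSON body and can either borrow a caller-provided `aiohttp.ClientSession` (as the rollout thread does with `self.session`) or create and close a one-off session per call. A minimal usage sketch against a hypothetical server address; the payload keys mirror the engine's `/generate` usage above and are illustrative, not a documented schema:

import asyncio

from arealite.utils.http import arequest_with_retry


async def check_generate(addr: str = "127.0.0.1:30000"):
    # One-off call: no session is passed, so the helper creates and closes its own.
    return await arequest_with_retry(
        addr=addr,
        endpoint="/generate",
        payload={"input_ids": [1, 2, 3], "sampling_params": {"max_new_tokens": 4}},
        method="POST",
        max_retries=3,
        timeout=60,
    )


if __name__ == "__main__":
    print(asyncio.run(check_generate()))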

@@ -8,7 +8,7 @@ from transformers import PreTrainedTokenizerFast
from arealite.api.cli_args import GenerationHyperparameters
from arealite.api.io_struct import LLMRequest
from arealite.api.workflow_api import RolloutWorkflow
from arealite.utils.padding import concat_padded_tensors
from arealite.utils.data import concat_padded_tensors


class RLVRWorkflow(RolloutWorkflow):
@@ -61,7 +61,7 @@ class RLVRWorkflow(RolloutWorkflow):
            versions=torch.tensor(versions).unsqueeze(0),
            attention_mask=torch.ones(len(seq), dtype=torch.bool).unsqueeze(0),
            # reward
            rewards=torch.tensor([reward]),
            rewards=torch.tensor([float(reward)]),
        )
        results.append(TensorDict(res, batch_size=[1]))


@@ -152,11 +152,7 @@ def main_grpo():

        with stats_tracker.record_timing("rollout"):
            if config.async_training:
                batch = rollout.prepare_batch(
                    data_generator,
                    train_dataloader,
                    workflow=workflow,
                )
                batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
            else:
                try:
                    data = next(data_generator)
@@ -210,8 +206,9 @@ def main_grpo():
            actor.upload_weights(meta)
            if dist.get_rank() == 0:
                future.result()
            rollout.set_version(global_step + 1)
            dist.barrier()
            torch.cuda.synchronize()
            rollout.set_version(global_step + 1)

        with stats_tracker.record_timing("save"):
            saver.save(actor, epoch, step, global_step)


@@ -1360,7 +1360,9 @@ class RayNameResolveRepository:

def make_repository(args: "NameResolveConfig"):
    if args.type == "nfs":
        return NfsNameRecordRepository(args.nfs_record_root)
        repo = NfsNameRecordRepository(args.nfs_record_root)
        os.makedirs(repo.record_root, exist_ok=True)
        return repo
    elif args.type == "etcd3":
        host, port = args.etcd3_addr.split(":")
        return Etcd3NameRecordRepository(host=host, port=int(port))