From 29e164a69d10d843e4772fa26198723c96310d00 Mon Sep 17 00:00:00 2001
From: Wei Fu <36355462+garrett4wade@users.noreply.github.com>
Date: Wed, 16 Jul 2025 17:26:49 +0800
Subject: [PATCH] [Fix] [lite] Merge from the internal repo to fix GRPO bugs and refactor the train engine (#181)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* PullRequest: 353 [Lite] Add gradient checkpointing to FSDPEngine
Merge branch mzy/add-gradient-ckpt of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/353
Reviewed-by: 博惟

* add gradient checkpointing

* PullRequest: 354 [lite] GRPO pre-commit: minor changes in FSDP engine
Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/354
Reviewed-by: 晓雷

* .
* .
* .
* .

* PullRequest: 355 [Lite] GRPO pre-commit 2: Refactor RemoteSGLangEngine thread and SGLang configuration
Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/355?tab=commit
Reviewed-by: 晓雷

* .
* .
* .
* .
* .
* .
* fix
* .

* PullRequest: 357 [lite] GRPO pre-commit 3: Fix typos and experiment utilities
Merge branch fw/lite-fix2 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/357?tab=comment
Reviewed-by: 晓雷

* .
* .
* .
* .
* .
* fix destroy process group

* PullRequest: 358 [lite] Support GRPO training locally with the GSM8k dataset
Merge branch fw/lite-fix3 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/358
Reviewed-by: 晓雷

* .
* .
* .
* .
* fix loss mask
* fix
* .

* PullRequest: 368 [lite] Refactor train engine after merging contributions from GitHub
Merge branch fw/lite-train-engine of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/368
Reviewed-by: 晓雷

* .
* .

* PullRequest: 371 [lite] [fix] fix misc bugs in GRPO implementation
Merge branch fw/lite-fix0716 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/371
Reviewed-by: 晓雷

* .
--------- Co-authored-by: 晓雷 --- arealite/api/cli_args.py | 87 +++- arealite/api/engine_api.py | 10 +- arealite/api/io_struct.py | 25 +- arealite/engine/autotp_engine.py | 128 +++++ arealite/engine/base_hf_engine.py | 360 ++++++++++++++ arealite/engine/fsdp_engine.py | 315 +----------- arealite/engine/hf_engine.py | 467 ------------------ arealite/engine/ppo/actor.py | 7 +- arealite/engine/sft/lm_engine.py | 6 +- arealite/engine/sglang_remote.py | 108 ++-- arealite/tests/test_sglang_engine.py | 12 + arealite/tests/test_sglang_local_engine.py | 4 - .../{test_engine.py => test_train_engine.py} | 13 +- arealite/utils/data.py | 28 +- arealite/utils/functional.py | 38 +- arealite/utils/http.py | 71 ++- arealite/workflow/rlvr.py | 4 +- examples/arealite/gsm8k_grpo.py | 9 +- realhf/base/name_resolve.py | 4 +- 19 files changed, 790 insertions(+), 906 deletions(-) create mode 100644 arealite/engine/autotp_engine.py create mode 100644 arealite/engine/base_hf_engine.py delete mode 100644 arealite/engine/hf_engine.py rename arealite/tests/{test_engine.py => test_train_engine.py} (92%) diff --git a/arealite/api/cli_args.py b/arealite/api/cli_args.py index 0f7aa9b..906f8fd 100644 --- a/arealite/api/cli_args.py +++ b/arealite/api/cli_args.py @@ -18,7 +18,7 @@ from arealite.utils.fs import get_user_tmp class MicroBatchSpec: """Specification for splitting micro-batches during training.""" - n_mbs: int = field( + n_mbs: Optional[int] = field( default=1, metadata={ "help": "Number of micro-batches (or minimum number if max_tokens_per_mb is set). Used when max_tokens_per_mb is None or as minimum count", @@ -161,7 +161,7 @@ class FSDPEngineConfig: @dataclass -class HFEngineConfig: +class DeepSpeedAutoTPEngineConfig: autotp_size: Optional[int] = field( default=1, metadata={"help": "DeepSpeed AutoTP size"}, @@ -201,7 +201,88 @@ class TrainEngineConfig: ) backend: str = "" fsdp: FSDPEngineConfig = field(default_factory=FSDPEngineConfig) - hf: HFEngineConfig = field(default_factory=HFEngineConfig) + ds_auto_tp: DeepSpeedAutoTPEngineConfig = field( + default_factory=DeepSpeedAutoTPEngineConfig + ) + + +@dataclass +class PPOActorConfig(TrainEngineConfig): + # Core PPO/GRPO Parameters + group_size: int = field( + default=1, metadata={"help": "Number of sequences in each group"} + ) + group_adv_norm: bool = field( + default=False, + metadata={ + "help": "Normalize advantages within each prompt group rather than globally" + }, + ) + ppo_n_minibatches: int = field( + default=4, metadata={"help": "Number of minibatches for each PPO update"} + ) + eps_clip: float = field( + default=0.2, metadata={"help": "Clipping factor for policy ratio"} + ) + c_clip: Optional[float] = field( + default=None, + metadata={ + "help": "Dual clipping factor for policy ratio, must > 1.0. None disables dual clipping." 
+ }, + ) + temperature: float = field( + default=1.0, metadata={"help": "Temperature during generation."} + ) + # Reward + group_reward_norm: bool = field( + default=False, + metadata={ + "help": "Normalize final reward of each sequence (GRPO-style) to reduce length bias" + }, + ) + reward_scaling: float = field( + default=1.0, metadata={"help": "Reward scaling factor"} + ) + reward_bias: float = field(default=0.0, metadata={"help": "Reward bias"}) + reward_clip: float = field( + default=20.0, metadata={"help": "Maximum absolute value for reward clipping"} + ) + mask_no_eos_with_zero: bool = field( + default=False, + metadata={ + "help": "Mask truncated generations (no EOS token) and exclude from training" + }, + ) + + # Advantage Estimation + discount: float = field( + default=1.0, metadata={"help": "Discount factor for future rewards"} + ) + gae_lambda: float = field( + default=1.0, metadata={"help": "Lambda parameter for GAE"} + ) + adv_norm: bool = field( + default=True, metadata={"help": "Enable advantage normalization"} + ) + + # KL Control + kl_ctl: float = field(default=0.1, metadata={"help": "KL divergence coefficient"}) + + # Asynchronous RL + recompute_logprob: bool = field( + default=False, + metadata={"help": "Recompute logp and replace the logp returned by inference."}, + ) + use_decoupled_loss: bool = field( + default=False, + metadata={"help": "Use the decoupled loss. recompute_logprob must be True."}, + ) + behav_imp_weight_cap: Optional[float] = field( + default=None, + metadata={ + "help": "We filter out the tokens where behav_imp_weight exceeds behav_imp_weight_cap when computing the loss, must be > 1.0, use_decoupled_loss must be true" + }, + ) @dataclass diff --git a/arealite/api/engine_api.py b/arealite/api/engine_api.py index 6334bf2..9fb24d2 100644 --- a/arealite/api/engine_api.py +++ b/arealite/api/engine_api.py @@ -25,10 +25,10 @@ class Scheduling: cpu: int gpu: int mem: int - nodelist: str = None - exclude: str = None - partition: str = None - container_image: str = None + nodelist: Optional[str] = None + exclude: Optional[str] = None + partition: Optional[str] = None + container_image: Optional[str] = None env_vars: Dict[str, str] = field(default_factory=dict) # time utils from "https://slurm.schedmd.com/sbatch.html" time_limit: Optional[str] = None # see "--time" option for format @@ -105,7 +105,7 @@ class TrainEngine(abc.ABC): def forward( self, input_: TensorDict, - output_seqlens: List[List[int]] | None = None, + output_seqlens: List[int] | None = None, post_hook: Callable[[torch.Tensor, TensorDict], Any] | None = None, aggregate_fn: Callable[[List[Any]], Any] = torch.cat, ) -> Any | None: diff --git a/arealite/api/io_struct.py b/arealite/api/io_struct.py index 83d2cc6..28a369f 100644 --- a/arealite/api/io_struct.py +++ b/arealite/api/io_struct.py @@ -71,7 +71,7 @@ class AllocationType(enum.Enum): @dataclass class AllocationMode: type_: AllocationType - parallel_strat: None | Dict[str, Dict[str, int]] + parallel_strat: Dict[str, Dict[str, int]] @property def gen_tp_size(self) -> int: @@ -115,7 +115,7 @@ class AllocationMode: raise NotImplementedError(f"Failed to parse allocation: {allocation_mode}") @staticmethod - def extract_3d_alloc(allocation_mode: str) -> Dict | None: + def extract_parallelism_strategy(allocation_mode: str) -> Dict: for x, y, z in itertools.permutations(["d", "t", "p"]): pattern = rf"{x}(\d+){y}(\d+){z}(\d+)" m = re.match(pattern, allocation_mode) @@ -130,29 +130,28 @@ class AllocationMode: z: c, } } + raise ValueError( + f"Unknown how 
to resolve parallelism strategy: {allocation_mode}" + ) @staticmethod - def extract_decoupled_alloc(allocation_mode: str) -> Dict | None: + def extract_decoupled_alloc(allocation_mode: str) -> Dict: pattern = re.compile( r"(?:(?:vllm|sglang)\.(.+?)\+(.+))|(?:(.+?)\+(?:vllm|sglang)\.(.+))" ) m = pattern.match(allocation_mode) if not m: - return + raise ValueError( + f"Unknown how to resolve decoupled allocation: {allocation_mode}" + ) if m.group(1): gen_alloc = m.group(1) other_alloc = m.group(2) else: gen_alloc = m.group(4) other_alloc = m.group(3) - gen_alloc = AllocationMode.extract_3d_alloc(gen_alloc) - if not gen_alloc: - return - other_alloc = AllocationMode.extract_3d_alloc( - other_alloc - ) or AllocationMode.extract_key_value_alloc(other_alloc) - if not other_alloc: - return + gen_alloc = AllocationMode.extract_parallelism_strategy(gen_alloc) + other_alloc = AllocationMode.extract_parallelism_strategy(other_alloc) other_alloc.update({"gen": gen_alloc["*"]}) return other_alloc @@ -171,7 +170,7 @@ class SaveLoadMeta: path: str weight_format: str with_optim: bool - tokenizer: PreTrainedTokenizerFast | None + tokenizer: Optional[PreTrainedTokenizerFast] base_model_path: str | None naive_distributed: bool = False diff --git a/arealite/engine/autotp_engine.py b/arealite/engine/autotp_engine.py new file mode 100644 index 0000000..e068b9c --- /dev/null +++ b/arealite/engine/autotp_engine.py @@ -0,0 +1,128 @@ +import os + +import torch +import torch.distributed as dist +from safetensors.torch import save_file + +from arealite.api.cli_args import TrainEngineConfig +from arealite.api.engine_api import FinetuneSpec, SaveLoadMeta, WeightUpdateMeta +from arealite.engine.base_hf_engine import BaseHFEngine +from arealite.utils.save_load import ( + get_state_dict_from_repo_id_or_path, + is_existing_local_path, +) +from realhf.base import constants, logging + +logger = logging.getLogger("DeepSpeedAutoTPEngine") + + +class DeepSpeedAutoTPEngine(BaseHFEngine): + def __init__(self, config: TrainEngineConfig): + super().__init__(config) + + def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None): + """Initialize distributed communication and model.""" + assert ( + addr is None + ), "DeepSpeedAutoTPEngine does not support remote initialization." + import deepspeed + + self.create_process_group() + + world_size = int(os.environ.get("WORLD_SIZE")) + deepspeed.init_distributed( + dist_backend="nccl", + world_size=world_size, + timeout=constants.NCCL_DEFAULT_TIMEOUT, + ) + self.create_device_model() + # NOTE: the device context manager does not work here. + self.model = deepspeed.tp_model_init( + self.model, + tp_size=self.config.ds_auto_tp.autotp_size, + dtype=getattr(torch, self.config.dtype), + ).to(self.device) + self.create_optimizer(ft_spec) + self.initialized = True + + def _check_autotp(self): + tp_size = self.config.ds_auto_tp.autotp_size + config = self.model_config + num_attention_heads = config.num_attention_heads + num_key_value_heads = config.num_key_value_heads + hidden_size = config.hidden_size + intermediate_size = config.intermediate_size + + return ( + num_attention_heads % tp_size == 0 + and num_key_value_heads % tp_size == 0 + and hidden_size % tp_size == 0 + and intermediate_size % tp_size == 0 + ) + + def save(self, meta: SaveLoadMeta): + if meta.weight_format != "naive_distributed": + raise ValueError(f"Unknown weight format {meta.weight_format}. 
") + if self.model is None: + raise RuntimeError("Model not initialized") + + rank = dist.get_rank() + world_size = dist.get_world_size() + if rank == 0: + os.makedirs(meta.path, exist_ok=True) + self.model_config.save_pretrained( + meta.path, + ) + if meta.tokenizer is not None: + meta.tokenizer.save_pretrained( + meta.path, + ) + + state_dict = self.model.state_dict() + if hasattr(self.model, "module"): + state_dict = { + k.replace("module.", "", 1) if k.startswith("module.") else k: v.cpu() + for k, v in state_dict.items() + } + else: + state_dict = {k: v.cpu() for k, v in state_dict.items()} + + # Only support store parameters from model partitions respectively + gathered_state_dicts = None + if rank == 0: + gathered_state_dicts = [None for _ in range(world_size)] + dist.gather_object( + obj=state_dict, object_gather_list=gathered_state_dicts, dst=0 + ) + if rank == 0: + for i, state_dict in enumerate(gathered_state_dicts): + save_file(state_dict, f"{meta.path}/rank_{i:02d}_model.safetensors") + if meta.with_optim: + self.save_optimizer_state(meta.path) + + def load(self, meta: SaveLoadMeta): + if meta.weight_format != "naive_distributed": + raise ValueError(f"Unknown weight format {meta.weight_format}. ") + rank = dist.get_rank() + # Only support load full model parameters from huggingface + # and load model partition locally + if rank == 0 or is_existing_local_path(meta.path): + path = f"{meta.path}/rank_{rank:02d}_model.safetensors" + full_state = get_state_dict_from_repo_id_or_path(meta.path) + + if hasattr(self.model, "module") and not hasattr(full_state): + full_state = { + f"module.{k}" if not k.startswith("module.") else k: v + for k, v in full_state.items() + } + self.model.load_state_dict( + full_state, strict=not self.model_config.tie_word_embeddings + ) + if self.model_config.tie_word_embeddings: + self.model.tie_weights() + + if meta.with_optim: + self.load_optimizer_state(meta.path) + + def upload_weights(self, meta: WeightUpdateMeta): + raise ValueError(f"update weight not implemented {meta.type}") diff --git a/arealite/engine/base_hf_engine.py b/arealite/engine/base_hf_engine.py new file mode 100644 index 0000000..bcbda8b --- /dev/null +++ b/arealite/engine/base_hf_engine.py @@ -0,0 +1,360 @@ +import gc +import os +import time +from typing import Any, Callable, Dict, List + +import torch +import torch.distributed as dist +from tensordict import TensorDict +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + PretrainedConfig, + PreTrainedTokenizerFast, + get_constant_schedule_with_warmup, + get_linear_schedule_with_warmup, +) + +from arealite.api.cli_args import TrainEngineConfig +from arealite.api.engine_api import FinetuneSpec, TrainEngine +from arealite.utils.data import ( + MicroBatchList, + amend_position_ids, + pack_tensor_dict, + pad_and_stack_tensors_along_first_dim, + pad_mb_list, + reorder_list, + split_padded_tensor_dict_into_mb_list, + unpack_sequence, + unsqueeze_mb_list, +) +from arealite.utils.fsdp import get_cosine_schedule_with_warmup +from arealite.utils.model import disable_dropout_in_model +from realhf.api.core.data_api import load_hf_tokenizer +from realhf.base import constants, logging + +logger = logging.getLogger("Base HF Engine") + + +class BaseHFEngine(TrainEngine): + def __init__(self, config: TrainEngineConfig): + self.config = config + self.optimizer_config = config.optimizer + + self.model: torch.nn.Module + self.optimizer: torch.optim.Optimizer + self.tokenizer: PreTrainedTokenizerFast + # huggingface model config + 
self.model_config: PretrainedConfig + + # initialization + self.initialized = False + self.own_global_group = False + self._parallelism_group: dist.ProcessGroup + self.weight_update_group_initialized = False + + self.world_size = int(os.environ["WORLD_SIZE"]) + + def train(self, mode: bool = True): + assert self.model is not None + self.model.train(mode=mode) + return self + + @property + def parallelism_group(self) -> dist.ProcessGroup: + assert self.initialized + return self._parallelism_group + + def create_process_group(self): + if not dist.is_initialized(): + # TODO: Handle the condition when WORLD_SIZE and RANK is not set in launcher + dist.init_process_group( + backend="nccl", + timeout=constants.NCCL_DEFAULT_TIMEOUT, + device_id=torch.device(int(os.environ["LOCAL_RANK"])), + ) + self.own_global_group = True + self._parallelism_group = dist.new_group() + + def create_device_model(self): + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + self.device = torch.device(int(os.environ["LOCAL_RANK"])) + + dtype = getattr(torch, self.config.dtype) + self.model_config = AutoConfig.from_pretrained( + pretrained_model_name_or_path=self.config.path, + trust_remote_code=True, + ) + self.tokenizer = load_hf_tokenizer(self.config.path) + tik = time.perf_counter() + with torch.device("cuda"): + if self.config.init_from_scratch: + # initialize scratch model from config + # NOTE: VLM cannot directly load state dict using this + # random initialized model, so otherwise we call + # from_pretrained rather than loading weights into this random model. + model = AutoModelForCausalLM.from_config( + self.model_config, + torch_dtype=dtype, + attn_implementation=self.config.attn_impl, + ) + else: + model = AutoModelForCausalLM.from_pretrained( + pretrained_model_name_or_path=self.config.path, + trust_remote_code=True, + torch_dtype=dtype, + attn_implementation=self.config.attn_impl, + ) + if self.config.disable_dropout: + disable_dropout_in_model(model) + if self.config.gradient_checkpointing: + model.gradient_checkpointing_enable( + gradient_checkpointing_kwargs={"use_reentrant": False} + ) + logger.info(f"Model creation and loading time: {time.perf_counter() - tik}") + self.model = model + + def create_optimizer(self, ft_spec: FinetuneSpec): + if self.optimizer_config is None: + return + assert self.model is not None + # Set up optimizer + tik = time.perf_counter() + assert ( + self.optimizer_config.type == "adam" + ), "Only AdamW optimizer is supported in this engine." 
+ lr = self.optimizer_config.lr + weight_decay = self.optimizer_config.weight_decay + beta1 = self.optimizer_config.beta1 + beta2 = self.optimizer_config.beta2 + eps = self.optimizer_config.eps + + self.optimizer = torch.optim.AdamW( + self.model.parameters(), + lr=lr, + weight_decay=weight_decay, + betas=(beta1, beta2), + eps=eps, + ) + total_train_steps = ft_spec.total_train_steps + num_warmup_steps = int( + self.optimizer_config.warmup_steps_proportion * total_train_steps + ) + + if self.optimizer_config.lr_scheduler_type == "cosine": + self.lr_scheduler = get_cosine_schedule_with_warmup( + self.optimizer, + num_warmup_steps, + total_train_steps, + min_lr_ratio=self.optimizer_config.min_lr_ratio, + ) + elif self.optimizer_config.lr_scheduler_type == "linear": + self.lr_scheduler = get_linear_schedule_with_warmup( + self.optimizer, + num_warmup_steps, + total_train_steps, + ) + elif self.optimizer_config.lr_scheduler_type == "constant": + self.lr_scheduler = get_constant_schedule_with_warmup( + self.optimizer, + num_warmup_steps, + ) + else: + raise ValueError( + f"Unknown lr scheduler type {self.optimizer_config.lr_scheduler_type}" + ) + logger.info(f"Create optimizer time: {time.perf_counter() - tik}") + + def destroy(self): + """Destroy the engine and release GPU memory.""" + del self.optimizer + del self.model + gc.collect() + torch.cuda.empty_cache() + gc.collect() + dist.destroy_process_group(self.parallelism_group) + if self.own_global_group: + dist.destroy_process_group() + self.initialized = False + + def save_optimizer_state(self, path: str): + # Save FSDP sharded state dict on each rank + assert self.optimizer is not None + assert dist.is_initialized() + rank = dist.get_rank() + shard_path = os.path.join( + path, f"optim_world_size_{self.world_size}_rank_{rank}.pt" + ) + state_dict = self.optimizer.state_dict() + torch.save(state_dict, shard_path) + dist.barrier() + + def load_optimizer_state(self, path: str): + # Load FSDP sharded state dict + assert self.optimizer is not None + assert dist.is_initialized() + rank = dist.get_rank() + shard_path = os.path.join( + path, f"optim_world_size_{self.world_size}_rank_{rank}.pt" + ) + optimizer_state_dict = torch.load(shard_path, weights_only=False) + self.optimizer.load_state_dict(optimizer_state_dict) + dist.barrier() + + def step_lr_scheduler(self): + assert self.lr_scheduler is not None + self.lr_scheduler.step() + + def prepare_mb_list(self, input_: TensorDict) -> MicroBatchList: + assert "attention_mask" in input_ and "input_ids" in input_ + if isinstance(input_, dict): + input_ = TensorDict(input_, batch_size=[input_["input_ids"].shape[0]]) + input_ = amend_position_ids(input_) + mb_list = split_padded_tensor_dict_into_mb_list(input_, self.config.mb_spec) + logger.info( + f"Microbatch #tokens (rank {dist.get_rank()}): {mb_list.group_lens}" + ) + mb_list.mbs = [pack_tensor_dict(mb) for mb in mb_list.mbs] + mb_list = pad_mb_list(mb_list, pad_value=0.0) + # NOTE: We unsqueeze here because huggingface transformer models requires + # packed input to be of shape [1, total_seqlen]. 
+ mb_list = unsqueeze_mb_list(mb_list) + # FIXME: the resulting max_seqlen is a tensor rather than an integer + for mb in mb_list.mbs: + mb["max_seqlen"] = int(mb["max_seqlen"]) + mb["use_cache"] = False + for mb in mb_list.padded_mbs: + mb["max_seqlen"] = int(mb["max_seqlen"]) + mb["use_cache"] = False + return mb_list + + def train_batch( + self, + input_: TensorDict, + loss_fn: Callable[[torch.Tensor, TensorDict], torch.Tensor], + loss_weight_fn: Callable[[TensorDict], float], + ) -> Dict[str, float]: + """Train on a batch using gradient accumulation.""" + input_ = input_.to(self.device) + assert self.optimizer is not None + assert self.optimizer_config is not None + assert self.lr_scheduler is not None + + self.optimizer.zero_grad() + mb_list = self.prepare_mb_list(input_) + + total_loss_weight = torch.tensor( + sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32 + ) + assert total_loss_weight != 0 + dist.all_reduce(total_loss_weight) + + # Process microbatches with gradient accumulation + for i, (pad_length, padded_mb_input, mb_input) in enumerate( + zip(mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs) + ): + outputs = self.model(**padded_mb_input) + + logits = outputs.logits.squeeze(0) + logits = logits[:-pad_length] if pad_length > 0 else logits + loss = loss_fn(logits, mb_input) + loss_scale = loss_weight_fn(mb_input) / total_loss_weight + + # Scale loss for accumulation + # Revert gradient averaging across dp ranks + loss_scale *= self.world_size + + loss *= loss_scale + loss.backward() + + grad_norm = torch.nn.utils.clip_grad_norm_( + self.model.parameters(), + self.optimizer_config.gradient_clipping, + norm_type=2.0, + error_if_nonfinite=False, + foreach=None, + ) + if not torch.isfinite(grad_norm): + self.optimizer.zero_grad() + update_successful = False + else: + self.optimizer.step() + update_successful = True + + current_lr = self.lr_scheduler.get_last_lr()[0] + # Optimizer step + self.optimizer.step() + return dict( + update_successful=float(update_successful), + grad_norm=float(grad_norm) if grad_norm is not None else float("nan"), + lr=current_lr, + ) + + @torch.no_grad() + def eval_batch( + self, + input_: TensorDict, + loss_fn: Callable[[torch.Tensor, TensorDict], torch.Tensor], + loss_weight_fn: Callable[[TensorDict], float], + ) -> torch.Tensor | None: + """Evaluate on a batch.""" + input_ = input_.to(self.device) + mb_list = self.prepare_mb_list(input_) + total_loss_weight = torch.tensor( + sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32 + ) + assert total_loss_weight != 0 + + total_loss = 0.0 + total_weight = 0.0 + + for pad_length, padded_mb_input, mb_input in zip( + mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs + ): + outputs = self.model(**padded_mb_input) + logits = outputs.logits.squeeze(0) + logits = logits[:-pad_length] if pad_length > 0 else logits + loss = loss_fn(logits, mb_input) + + # Simple weight calculation (could be improved) + loss_scale = loss_weight_fn(mb_input) / total_loss_weight + total_loss += loss.item() * loss_scale + total_weight += loss_scale + + return torch.tensor(total_loss / total_weight) + + @torch.no_grad() + def forward( + self, + input_: TensorDict, + output_seqlens: List[int] | None = None, + post_hook: Callable[[torch.Tensor, TensorDict], Any] | None = None, + aggregate_fn: Callable[[List[Any]], Any] = torch.cat, + ) -> Any | None: + """Forward pass with optional post-processing.""" + input_ = input_.to(self.device) + cu_seqlens = pack_tensor_dict(input_)["cu_seqlens"] + 
mb_list = self.prepare_mb_list(input_) + + if output_seqlens is None: + output_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist() + + results = [] + for pad_length, padded_mb_input, mb_input in zip( + mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs + ): + outputs = self.model(**padded_mb_input) + logits = outputs.logits.squeeze(0) + logits = logits[:-pad_length] if pad_length > 0 else logits + + if post_hook: + result = post_hook(logits, mb_input) + results.append(result) + else: + results.append(logits) + + res = aggregate_fn(results) + output_seqlens = [output_seqlens[i] for i in mb_list.forward_indices] + unpacked = unpack_sequence(res, lens=output_seqlens, dim=0) + reordered = reorder_list(unpacked, mb_list.backward_indices) + return pad_and_stack_tensors_along_first_dim(reordered) diff --git a/arealite/engine/fsdp_engine.py b/arealite/engine/fsdp_engine.py index 07efc36..34131ec 100644 --- a/arealite/engine/fsdp_engine.py +++ b/arealite/engine/fsdp_engine.py @@ -1,42 +1,20 @@ -import gc import os import time from datetime import datetime -from typing import Any, Callable, Dict, List, Optional +from typing import Callable, Dict, Optional import torch import torch.distributed as dist -import transformers from tensordict import TensorDict from torch.distributed.checkpoint.state_dict import ( StateDictOptions, get_model_state_dict, ) -from transformers import ( - AutoConfig, - AutoModelForCausalLM, - get_constant_schedule_with_warmup, - get_linear_schedule_with_warmup, -) +from transformers import PreTrainedTokenizerFast from arealite.api.cli_args import TrainEngineConfig -from arealite.api.engine_api import ( - FinetuneSpec, - SaveLoadMeta, - TrainEngine, - WeightUpdateMeta, -) -from arealite.utils.data import ( - MicroBatchList, - amend_position_ids, - pack_tensor_dict, - pad_and_stack_tensors_along_first_dim, - pad_mb_list, - reorder_list, - split_padded_tensor_dict_into_mb_list, - unpack_sequence, - unsqueeze_mb_list, -) +from arealite.api.engine_api import FinetuneSpec, SaveLoadMeta, WeightUpdateMeta +from arealite.engine.base_hf_engine import BaseHFEngine from arealite.utils.fsdp import ( CPUOffloadPolicy, MixedPrecisionPolicy, @@ -44,108 +22,35 @@ from arealite.utils.fsdp import ( create_fsdp_device_mesh, fsdp2_clip_grad_norm_, fsdp2_load_full_state_dict, - get_cosine_schedule_with_warmup, ) -from arealite.utils.model import disable_dropout_in_model from arealite.utils.save_load import get_state_dict_from_repo_id_or_path -from realhf.api.core.data_api import load_hf_tokenizer -from realhf.base import constants, logging, name_resolve, names, pkg_version +from realhf.base import logging, name_resolve, names, pkg_version logger = logging.getLogger("FSDPEngine") -class FSDPEngine(TrainEngine): +class FSDPEngine(BaseHFEngine): def __init__(self, config: TrainEngineConfig): - self.config = config - self.optimizer_config = config.optimizer - - self.model = None - self.optimizer = None - self.tokenizer = None - # huggingface model config - self.model_config = None + super().__init__(config) # FSDP options self.mixed_precision_policy = None self.device_mesh = None self.cpu_offload = None - # initialization - self.initialized = False - self.own_global_group = False - self._parallelism_group = None - self.weight_update_group_initialized = False - - # TODO: Handle the case when WORLD_SIZE is not set in launcher - self.world_size = int(os.environ["WORLD_SIZE"]) - - def train(self, mode: bool = True): - assert self.model is not None - self.model.train(mode=mode) - 
return self - - @property - def parallelism_group(self) -> dist.ProcessGroup: - assert self.initialized - return self._parallelism_group def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None): # Initialize distributed enviroments and load model. assert addr is None, "FSDPEngine does not support remote initialization." - assert pkg_version.is_version_greater_or_equal( "torch", "2.4.0" ), f"arealite only supports FSDP2, which requires torch>=2.4.0" - """Initialize distributed communication and model.""" - if not dist.is_initialized(): - # TODO: Handle the condition when WORLD_SIZE and RANK is not set in launcher - dist.init_process_group( - backend="nccl", - timeout=constants.NCCL_DEFAULT_TIMEOUT, - device_id=torch.device(int(os.environ["LOCAL_RANK"])), - ) - self.own_global_group = True - self._parallelism_group = dist.new_group() - - # TODO: Handle the condition when LOCAL_RANK is not set in launcher - torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) - self.device = torch.device(int(os.environ["LOCAL_RANK"])) - - dtype = getattr(torch, self.config.dtype) - self.model_config = AutoConfig.from_pretrained( - pretrained_model_name_or_path=self.config.path, - trust_remote_code=True, - ) - self.tokenizer = load_hf_tokenizer(self.config.path) - tik = time.perf_counter() - with torch.device("cuda"): - if self.config.init_from_scratch: - # initialize scratch model from config - # NOTE: VLM cannot directly load state dict using this - # random initialized model, so otherwise we call - # from_pretrained rather than loading weights into this random model. - model = AutoModelForCausalLM.from_config( - self.model_config, - torch_dtype=dtype, - attn_implementation=self.config.attn_impl, - ) - else: - model = AutoModelForCausalLM.from_pretrained( - pretrained_model_name_or_path=self.config.path, - trust_remote_code=True, - torch_dtype=dtype, - attn_implementation=self.config.attn_impl, - ) - if self.config.disable_dropout: - disable_dropout_in_model(model) - if self.config.gradient_checkpointing: - model.gradient_checkpointing_enable( - gradient_checkpointing_kwargs={"use_reentrant": False} - ) - logger.info(f"Model creation and loading time: {time.perf_counter() - tik}") + self.create_process_group() + self.create_device_model() + # Wrap with FSDP2 # Simple auto wrap policy self.mixed_precision_policy = MixedPrecisionPolicy( - param_dtype=dtype, + param_dtype=getattr(torch, self.config.dtype), reduce_dtype=torch.float32, cast_forward_inputs=True, ) @@ -154,82 +59,19 @@ class FSDPEngine(TrainEngine): self.cpu_offload = ( CPUOffloadPolicy() if self.config.fsdp.offload_params else None ) - fsdp_kwargs = { "mesh": self.device_mesh, "mp_policy": self.mixed_precision_policy, "offload_policy": self.cpu_offload, "reshard_after_forward": True, } - - # Wrap with FSDP2 tik = time.perf_counter() - apply_fsdp2(model, fsdp_kwargs, self.config.fsdp.wrap_policy) + apply_fsdp2(self.model, fsdp_kwargs, self.config.fsdp.wrap_policy) logger.info(f"Applying FSDP2 time: {time.perf_counter() - tik}") - self.model = model - - # Set up optimizer - if self.optimizer_config is not None: - tik = time.perf_counter() - assert ( - self.optimizer_config.type == "adam" - ), "Only AdamW optimizer is supported in this engine." 
- lr = self.optimizer_config.lr - weight_decay = self.optimizer_config.weight_decay - beta1 = self.optimizer_config.beta1 - beta2 = self.optimizer_config.beta2 - eps = self.optimizer_config.eps - - self.optimizer = torch.optim.AdamW( - self.model.parameters(), - lr=lr, - weight_decay=weight_decay, - betas=(beta1, beta2), - eps=eps, - ) - total_train_steps = ft_spec.total_train_steps - num_warmup_steps = int( - self.optimizer_config.warmup_steps_proportion * total_train_steps - ) - - if self.optimizer_config.lr_scheduler_type == "cosine": - self.lr_scheduler = get_cosine_schedule_with_warmup( - self.optimizer, - num_warmup_steps, - total_train_steps, - min_lr_ratio=self.optimizer_config.min_lr_ratio, - ) - elif self.optimizer_config.lr_scheduler_type == "linear": - self.lr_scheduler = get_linear_schedule_with_warmup( - self.optimizer, - num_warmup_steps, - total_train_steps, - ) - elif self.optimizer_config.lr_scheduler_type == "constant": - self.lr_scheduler = get_constant_schedule_with_warmup( - self.optimizer, - num_warmup_steps, - ) - else: - raise ValueError( - f"Unknown lr scheduler type {self.optimizer_config.lr_scheduler_type}" - ) - logger.info(f"Create optimizer time: {time.perf_counter() - tik}") + self.create_optimizer(ft_spec) self.initialized = True - def destroy(self): - """Destroy the engine and release GPU memory.""" - self.model = None - self.optimizer = None - gc.collect() - torch.cuda.empty_cache() - gc.collect() - dist.destroy_process_group(self.parallelism_group) - if self.own_global_group: - dist.destroy_process_group() - self.initialized = False - def save(self, meta: SaveLoadMeta): if meta.weight_format == "hf": self._save_model_to_hf(meta.path, meta.tokenizer) @@ -240,7 +82,7 @@ class FSDPEngine(TrainEngine): raise ValueError(f"Unknown weight format {meta.weight_format}. ") if meta.with_optim: - self._save_optimizer_state(meta.path) + self.save_optimizer_state(meta.path) def load(self, meta: SaveLoadMeta): if meta.weight_format == "hf": @@ -252,34 +94,10 @@ class FSDPEngine(TrainEngine): raise ValueError(f"Unknown weight format {meta.weight_format}. ") if meta.with_optim: - self._load_optimizer_state(meta.path) - - def _save_optimizer_state(self, path: str): - # Save FSDP sharded state dict on each rank - assert self.optimizer is not None - assert dist.is_initialized() - rank = dist.get_rank() - shard_path = os.path.join( - path, f"optim_world_size_{self.world_size}_rank_{rank}.pt" - ) - state_dict = self.optimizer.state_dict() - torch.save(state_dict, shard_path) - dist.barrier() - - def _load_optimizer_state(self, path: str): - # Load FSDP sharded state dict - assert self.optimizer is not None - assert dist.is_initialized() - rank = dist.get_rank() - shard_path = os.path.join( - path, f"optim_world_size_{self.world_size}_rank_{rank}.pt" - ) - optimizer_state_dict = torch.load(shard_path, weights_only=False) - self.optimizer.load_state_dict(optimizer_state_dict) - dist.barrier() + self.load_optimizer_state(meta.path) def _save_model_to_hf( - self, path: str, tokenizer: Optional[transformers.PreTrainedTokenizerFast] + self, path: str, tokenizer: Optional[PreTrainedTokenizerFast] ): """Save model in HuggingFace format.""" if self.model is None: @@ -345,35 +163,11 @@ class FSDPEngine(TrainEngine): "Distributed weight update is not implemented for FSDPEngine yet. 
" ) - def step_lr_scheduler(self): - assert self.lr_scheduler is not None - self.lr_scheduler.step() - - def _prepare_mb_list(self, input_: TensorDict) -> MicroBatchList: - assert "attention_mask" in input_ and "input_ids" in input_ - if isinstance(input_, dict): - input_ = TensorDict(input_, batch_size=[input_["input_ids"].shape[0]]) - input_ = amend_position_ids(input_) - mb_list = split_padded_tensor_dict_into_mb_list(input_, self.config.mb_spec) - logger.info( - f"Microbatch #tokens (rank {dist.get_rank()}): {mb_list.group_lens}" - ) - mb_list.mbs = [pack_tensor_dict(mb) for mb in mb_list.mbs] - mb_list = pad_mb_list(mb_list, pad_value=0.0) - # NOTE: We unsqueeze here because huggingface transformer models requires - # packed input to be of shape [1, total_seqlen]. - mb_list = unsqueeze_mb_list(mb_list) - # FIXME: the resulting max_seqlen is a tensor rather than an integer - for mb in mb_list.mbs: - mb["max_seqlen"] = int(mb["max_seqlen"]) - mb["use_cache"] = False - return mb_list - def train_batch( self, input_: TensorDict, - loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor], - loss_weight_fn: Callable[[Dict], float], + loss_fn: Callable[[torch.Tensor, TensorDict], torch.Tensor], + loss_weight_fn: Callable[[TensorDict], float], ) -> Dict[str, float]: """Train on a batch using gradient accumulation.""" input_ = input_.to(self.device) @@ -382,7 +176,7 @@ class FSDPEngine(TrainEngine): assert self.lr_scheduler is not None self.optimizer.zero_grad() - mb_list = self._prepare_mb_list(input_) + mb_list = self.prepare_mb_list(input_) total_loss_weight = torch.tensor( sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32 @@ -394,7 +188,6 @@ class FSDPEngine(TrainEngine): for i, (pad_length, padded_mb_input, mb_input) in enumerate( zip(mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs) ): - self.model.set_is_last_backward(i == len(mb_list.mbs) - 1) outputs = self.model(**padded_mb_input) logits = outputs.logits.squeeze(0) @@ -409,6 +202,7 @@ class FSDPEngine(TrainEngine): loss *= loss_scale loss.backward() + # NOTE: grad norm clip function is different grad_norm = fsdp2_clip_grad_norm_( self.model.parameters(), max_norm=self.optimizer_config.gradient_clipping ) @@ -427,72 +221,3 @@ class FSDPEngine(TrainEngine): grad_norm=float(grad_norm) if grad_norm is not None else float("nan"), lr=current_lr, ) - - @torch.no_grad() - def eval_batch( - self, - input_: TensorDict, - loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor], - loss_weight_fn: Callable[[Dict], float], - ) -> torch.Tensor | None: - """Evaluate on a batch.""" - input_ = input_.to(self.device) - mb_list = self._prepare_mb_list(input_) - total_loss_weight = torch.tensor( - sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32 - ) - assert total_loss_weight != 0 - - total_loss = 0.0 - total_weight = 0.0 - - for pad_length, padded_mb_input, mb_input in zip( - mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs - ): - outputs = self.model(**padded_mb_input) - logits = outputs.logits.squeeze(0) - logits = logits[:-pad_length] if pad_length > 0 else logits - loss = loss_fn(logits, mb_input) - - # Simple weight calculation (could be improved) - loss_scale = loss_weight_fn(mb_input) / total_loss_weight - total_loss += loss.item() * loss_scale - total_weight += loss_scale - - return torch.tensor(total_loss / total_weight) - - @torch.no_grad() - def forward( - self, - input_: TensorDict, - output_seqlens: List[int] | None = None, - post_hook: Callable[[torch.Tensor, Dict], Any] | None = None, - 
aggregate_fn: Callable[[List[Any]], Any] = torch.cat, - ) -> Any | None: - """Forward pass with optional post-processing.""" - input_ = input_.to(self.device) - cu_seqlens = pack_tensor_dict(input_)["cu_seqlens"] - mb_list = self._prepare_mb_list(input_) - - if output_seqlens is None: - output_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist() - - results = [] - for pad_length, padded_mb_input, mb_input in zip( - mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs - ): - outputs = self.model(**padded_mb_input) - logits = outputs.logits.squeeze(0) - logits = logits[:-pad_length] if pad_length > 0 else logits - - if post_hook: - result = post_hook(logits, mb_input) - results.append(result) - else: - results.append(logits) - - res = aggregate_fn(results) - output_seqlens = [output_seqlens[i] for i in mb_list.forward_indices] - unpacked = unpack_sequence(res, lens=output_seqlens, dim=0) - reordered = reorder_list(unpacked, mb_list.backward_indices) - return pad_and_stack_tensors_along_first_dim(reordered) diff --git a/arealite/engine/hf_engine.py b/arealite/engine/hf_engine.py deleted file mode 100644 index ee93f9f..0000000 --- a/arealite/engine/hf_engine.py +++ /dev/null @@ -1,467 +0,0 @@ -import gc -import os -import time -from typing import Any, Callable, Dict, List, Optional - -import torch -import torch.distributed as dist -import transformers -from safetensors.torch import save_file -from tensordict import TensorDict -from transformers import ( - AutoConfig, - AutoModelForCausalLM, - get_constant_schedule_with_warmup, - get_linear_schedule_with_warmup, -) - -from arealite.api.cli_args import TrainEngineConfig -from arealite.api.engine_api import ( - FinetuneSpec, - SaveLoadMeta, - TrainEngine, - WeightUpdateMeta, -) -from arealite.utils.data import ( - MicroBatchList, - amend_position_ids, - pack_tensor_dict, - pad_and_stack_tensors_along_first_dim, - pad_mb_list, - reorder_list, - split_packed_tensor_dict_into_mb_list, - unpack_sequence, - unsqueeze_mb_list, -) -from arealite.utils.fsdp import get_cosine_schedule_with_warmup -from arealite.utils.save_load import ( - get_state_dict_from_repo_id_or_path, - is_existing_local_path, -) -from realhf.api.core.data_api import load_hf_tokenizer -from realhf.base import logging, name_resolve, names - -logger = logging.getLogger("HFEngine") - - -class HFEngine(TrainEngine): - def __init__(self, config: TrainEngineConfig): - self.config = config - self.optimizer_config = config.optimizer - - self.model = None - self.optimizer = None - self.tokenizer = None - # huggingface model config - self.model_config = None - # initialization - self.initialized = False - self.weight_update_group_initialized = False - - def train(self, mode: bool = True): - assert self.model is not None - self.model.train(mode=mode) - return self - - def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None): - """Initialize distributed communication and model.""" - assert addr is None, "HFEngine does not support remote initialization." - - world_size = int(os.environ.get("WORLD_SIZE", 0)) - if not dist.is_initialized() and world_size > 1: - try: - import deepspeed - except ImportError: - print( - "Warning: deepspeed is not installed. Some functionality may be disabled." 
- ) - deepspeed.init_distributed(dist_backend="nccl", world_size=world_size) - - local_rank = int(os.environ.get("LOCAL_RANK", 0)) - torch.cuda.set_device(local_rank) - self.device = torch.device(f"cuda:{local_rank}") - - dtype = getattr(torch, self.config.dtype) - self.model_config = AutoConfig.from_pretrained( - pretrained_model_name_or_path=self.config.path, - trust_remote_code=True, - ) - self.tokenizer = load_hf_tokenizer(self.config.path) - - self.model = AutoModelForCausalLM.from_config( - self.model_config, - torch_dtype=dtype, - attn_implementation=self.config.attn_impl, - ).to(f"cuda:{local_rank}") - - if not self.config.init_from_scratch: - # Load model from a initial checkpoint path, - # which should only be a huggingface checkpoint. - load_meta = SaveLoadMeta( - path=self.config.path, - weight_format="hf", - with_optim=False, - tokenizer=None, - base_model_path=self.config.path, - naive_distributed=False, - ) - - self.load(load_meta) - - if world_size > 1: - if self._check_autotp(): - self.model = deepspeed.tp_model_init( - self.model, tp_size=self.config.hf.autotp_size, dtype=dtype - ) - else: - raise RuntimeError("DeepSpeed AutoTP configuration error in HFEngine. ") - - # Set up optimizer - if self.optimizer_config is not None: - assert ( - self.optimizer_config.type == "adam" - ), "Only AdamW optimizer is supported in this engine." - lr = self.optimizer_config.lr - weight_decay = self.optimizer_config.weight_decay - beta1 = self.optimizer_config.beta1 - beta2 = self.optimizer_config.beta2 - eps = self.optimizer_config.eps - - self.optimizer = torch.optim.AdamW( - self.model.parameters(), - lr=lr, - weight_decay=weight_decay, - betas=(beta1, beta2), - eps=eps, - ) - total_train_steps = ft_spec.total_train_steps - num_warmup_steps = int( - self.optimizer_config.warmup_steps_proportion * total_train_steps - ) - - if self.optimizer_config.lr_scheduler_type == "cosine": - self.lr_scheduler = get_cosine_schedule_with_warmup( - self.optimizer, - num_warmup_steps, - total_train_steps, - min_lr_ratio=self.optimizer_config.min_lr_ratio, - ) - elif self.optimizer_config.lr_scheduler_type == "linear": - self.lr_scheduler = get_linear_schedule_with_warmup( - self.optimizer, - num_warmup_steps, - total_train_steps, - ) - elif self.optimizer_config.lr_scheduler_type == "constant": - self.lr_scheduler = get_constant_schedule_with_warmup( - self.optimizer, - num_warmup_steps, - ) - else: - raise ValueError( - f"Unknown lr scheduler type {self.optimizer_config.lr_scheduler_type}" - ) - - self.initialized = True - - def _check_autotp(self): - tp_size = self.config.hf.autotp_size - config = self.model_config - num_attention_heads = config.num_attention_heads - num_key_value_heads = config.num_key_value_heads - hidden_size = config.hidden_size - intermediate_size = config.intermediate_size - - return ( - num_attention_heads % tp_size == 0 - and num_key_value_heads % tp_size == 0 - and hidden_size % tp_size == 0 - and intermediate_size % tp_size == 0 - ) - - def destroy(self): - """Destroy the engine and release GPU memory.""" - self.model = None - self.optimizer = None - gc.collect() - torch.cuda.empty_cache() - gc.collect() - self.initialized = False - - def save(self, meta: SaveLoadMeta): - if meta.weight_format == "hf": - self._save_model_to_hf(meta.path, meta.tokenizer, meta.naive_distributed) - elif meta.weight_format == "dcp": - # TODO: implement DCP save/load for HF - raise NotImplementedError("DCP format saving is not implemented yet. 
") - else: - raise ValueError(f"Unknown weight format {meta.weight_format}. ") - - if meta.with_optim: - self._save_optimizer_state(meta.path) - - def load(self, meta: SaveLoadMeta): - if meta.weight_format == "hf": - self._load_model_from_hf(meta.path, meta.naive_distributed) - elif meta.weight_format == "dcp": - # TODO: implement DCP save/load for HF - raise NotImplementedError("DCP format loading is not implemented yet. ") - else: - raise ValueError(f"Unknown weight format {meta.weight_format}. ") - - if meta.with_optim: - self._load_optimizer_state(meta.path) - - def _save_optimizer_state(self, path: str): - assert self.optimizer is not None - os.makedirs(path, exist_ok=True) - torch.save(self.optimizer.state_dict(), os.path.join(path, "optim.pt")) - - def _load_optimizer_state(self, path: str): - assert self.optimizer is not None - path = os.path.join(path, "optim.pt") - optimizer_state_dict = torch.load(path, weights_only=False) - self.optimizer.load_state_dict(optimizer_state_dict) - - def _save_model_to_hf( - self, - path: str, - tokenizer: Optional[transformers.PreTrainedTokenizerFast], - naive_distributed: bool, - ): - """Save model in HuggingFace format.""" - if self.model is None: - raise RuntimeError("Model not initialized") - - rank = dist.get_rank() - world_size = dist.get_world_size() - if rank == 0: - os.makedirs(path, exist_ok=True) - self.model_config.save_pretrained(path) - if tokenizer is not None: - tokenizer.save_pretrained(path) - - if world_size > 1: - dist.barrier() - - state_dict = self.model.state_dict() - - if hasattr(self.model, "module"): - state_dict = { - k.replace("module.", "", 1) if k.startswith("module.") else k: v.cpu() - for k, v in state_dict.items() - } - else: - state_dict = {k: v.cpu() for k, v in state_dict.items()} - - if world_size > 1 and naive_distributed: - # Only support store parameters from model partitions respectively - gathered_state_dicts = None - if rank == 0: - gathered_state_dicts = [None for _ in range(world_size)] - - dist.gather_object( - obj=state_dict, object_gather_list=gathered_state_dicts, dst=0 - ) - - if rank == 0: - for i, state_dict in enumerate(gathered_state_dicts): - save_file(state_dict, f"{path}/rank_{i:02d}_model.safetensors") - else: - self.model.save_pretrained(path, state_dict=state_dict) - - if world_size > 1: - dist.barrier() - - def _load_model_from_hf(self, path: str, naive_distributed: bool): - """Load model from HuggingFace format.""" - - rank = dist.get_rank() - # Only support load full model parameters from huggingface - # and load model partition locally - if rank == 0 or is_existing_local_path(path): - if naive_distributed: - path = f"{path}/rank_{rank:02d}_model.safetensors" - full_state = get_state_dict_from_repo_id_or_path(path) - - if hasattr(self.model, "module") and not hasattr(full_state): - full_state = { - f"module.{k}" if not k.startswith("module.") else k: v - for k, v in full_state.items() - } - self.model.load_state_dict( - full_state, strict=not self.model_config.tie_word_embeddings - ) - - if self.model_config.tie_word_embeddings: - self.model.tie_weights() - - def upload_weights(self, meta: WeightUpdateMeta): - if meta.type == "nccl": - if not self.weight_update_group_initialized: - self._init_distributed_weight_update(meta) - self._update_weights_from_distributed() - elif meta.type == "disk": - self._save_model_to_hf(meta.path, self.tokenizer, meta.naive_distributed) - update_name = names.update_weights_from_disk( - self.config.experiment_name, - self.config.trial_name, - 
meta.model_version, - ) - name_resolve.add(update_name, str(time.time_ns()), keepalive_ttl=120) - else: - raise ValueError(f"Unknown weight update type {meta.type}") - - def _init_distributed_weight_update(self, meta: WeightUpdateMeta): - raise NotImplementedError( - "Distributed weight update is not implemented for HFEngine yet. " - ) - - def _update_weights_from_distributed(self): - raise NotImplementedError( - "Distributed weight update is not implemented for HFEngine yet. " - ) - - def step_lr_scheduler(self): - assert self.lr_scheduler is not None - return self.lr_scheduler.step() - - def _prepare_mb_list(self, input_: TensorDict) -> MicroBatchList: - assert "attention_mask" in input_ and "input_ids" in input_ - if isinstance(input_, dict): - input_ = TensorDict(input_, batch_size=[input_["input_ids"].shape[0]]) - input_ = amend_position_ids(input_) - packed_input = pack_tensor_dict(input_) - mb_list = split_packed_tensor_dict_into_mb_list( - packed_input, - self.config.mb_spec, - ) - mb_list = pad_mb_list(mb_list, pad_value=0.0) - # NOTE: We unsqueeze here because huggingface transformer models requires - # packed input to be of shape [1, total_seqlen]. - mb_list = unsqueeze_mb_list(mb_list) - return mb_list - - def train_batch( - self, - input_: TensorDict, - loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor], - loss_weight_fn: Callable[[Dict], float], - ) -> Dict[str, float]: - """Train on a batch using gradient accumulation.""" - input_ = input_.to(self.device) - assert self.optimizer is not None - assert self.optimizer_config is not None - assert self.lr_scheduler is not None - - self.optimizer.zero_grad() - mb_list = self._prepare_mb_list(input_) - - total_loss_weight = torch.tensor( - sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32 - ) - assert total_loss_weight != 0 - - # Process microbatches with gradient accumulation - for pad_length, padded_mb_input, mb_input in zip( - mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs - ): - outputs = self.model(**padded_mb_input) - - logits = outputs.logits.squeeze(0) - logits = logits[:-pad_length] if pad_length > 0 else logits - loss = loss_fn(logits, mb_input) - loss_scale = loss_weight_fn(mb_input) / total_loss_weight - - loss *= loss_scale - loss.backward() - - grad_norm = torch.nn.utils.clip_grad_norm_( - self.model.parameters(), - self.optimizer_config.gradient_clipping, - norm_type=2.0, - error_if_nonfinite=False, - foreach=None, - ) - if not torch.isfinite(grad_norm): - self.optimizer.zero_grad() - update_successful = False - else: - self.optimizer.step() - update_successful = True - - current_lr = self.lr_scheduler.get_last_lr()[0] - # Optimizer step - self.optimizer.step() - return dict( - update_successful=float(update_successful), - grad_norm=float(grad_norm) if grad_norm is not None else float("nan"), - lr=current_lr, - ) - - @torch.no_grad() - def eval_batch( - self, - input_: TensorDict, - loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor], - loss_weight_fn: Callable[[Dict], float], - ) -> torch.Tensor | None: - """Evaluate on a batch.""" - mb_list = self._prepare_mb_list(input_) - total_loss_weight = torch.tensor( - sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32 - ) - assert total_loss_weight != 0 - - total_loss = 0.0 - total_weight = 0.0 - - for pad_length, padded_mb_input, mb_input in zip( - mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs - ): - outputs = self.model(**padded_mb_input) - logits = outputs.logits.squeeze(0) - logits = logits[:-pad_length] if 
pad_length > 0 else logits - loss = loss_fn(logits, mb_input) - - # Simple weight calculation (could be improved) - loss_scale = loss_weight_fn(mb_input) / total_loss_weight - total_loss += loss.item() * loss_scale - total_weight += loss_scale - - return torch.tensor(total_loss / total_weight) - - @torch.no_grad() - def forward( - self, - input_: TensorDict, - output_seqlens: List[int] | None = None, - post_hook: Callable[[torch.Tensor, Dict], Any] | None = None, - aggregate_fn: Callable[[List[Any]], Any] = torch.cat, - ) -> Any | None: - """Forward pass with optional post-processing.""" - cu_seqlens = pack_tensor_dict(input_)["cu_seqlens"] - mb_list = self._prepare_mb_list(input_) - - if output_seqlens is None: - output_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist() - - results = [] - for pad_length, padded_mb_input, mb_input in zip( - mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs - ): - outputs = self.model(**padded_mb_input) - logits = outputs.logits.squeeze(0) - logits = logits[:-pad_length] if pad_length > 0 else logits - - if post_hook: - result = post_hook(logits, mb_input) - results.append(result) - else: - results.append(logits) - - res = aggregate_fn(results) - output_seqlens = [output_seqlens[i] for i in mb_list.forward_indices] - unpacked = unpack_sequence(res, lens=output_seqlens, dim=0) - reordered = reorder_list(unpacked, mb_list.backward_indices) - return pad_and_stack_tensors_along_first_dim(reordered) diff --git a/arealite/engine/ppo/actor.py b/arealite/engine/ppo/actor.py index 937f32d..802c03e 100644 --- a/arealite/engine/ppo/actor.py +++ b/arealite/engine/ppo/actor.py @@ -114,7 +114,9 @@ class PPOActor: values = torch.zeros_like(rewards) else: values = data["values"] - advantages_reversed = [] + advantages_reversed = [ + torch.zeros(bs, dtype=torch.float32, device=values.device) + ] lastgaelam = 0 for t in reversed(range(max_seqlen - 1)): nextvalues = values[:, t + 1] @@ -123,9 +125,6 @@ class PPOActor: delta = rewards[:, t] + self.discount * nextvalues - values[:, t] lastgaelam = delta + self.discount * self.gae_lambda * lastgaelam advantages_reversed.append(lastgaelam) - advantages_reversed.append( - torch.zeros(bs, dtype=torch.float32, device=values.device) - ) advantages = torch.stack(advantages_reversed[::-1], dim=1) # Optionally perform advantage normalization. 
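The actor.py hunk above moves the zero bootstrap for the final timestep to the front of advantages_reversed, so that after the trailing [::-1] reversal the zero lands on the last position instead of shifting every advantage forward by one step. A minimal standalone sketch of the corrected GAE recursion, under the assumption of dense per-step rewards and [bs, max_seqlen] tensors (a hypothetical helper that ignores the loss-mask handling in the real PPOActor):

import torch

def gae_advantages(rewards: torch.Tensor, values: torch.Tensor,
                   discount: float = 1.0, gae_lambda: float = 1.0) -> torch.Tensor:
    # rewards, values: [bs, max_seqlen]; the last position is the terminal step.
    bs, max_seqlen = rewards.shape
    # The terminal step's zero advantage goes in first, so that after the
    # final [::-1] reversal it ends up at the last timestep (as in the fix
    # above) rather than at t=0, where it would misalign every advantage.
    advantages_reversed = [torch.zeros(bs, dtype=torch.float32, device=values.device)]
    lastgaelam = torch.zeros(bs, dtype=torch.float32, device=values.device)
    for t in reversed(range(max_seqlen - 1)):
        nextvalues = values[:, t + 1]
        delta = rewards[:, t] + discount * nextvalues - values[:, t]
        lastgaelam = delta + discount * gae_lambda * lastgaelam
        advantages_reversed.append(lastgaelam)
    return torch.stack(advantages_reversed[::-1], dim=1)  # [bs, max_seqlen]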
diff --git a/arealite/engine/sft/lm_engine.py b/arealite/engine/sft/lm_engine.py index 0a6bec2..18bee78 100644 --- a/arealite/engine/sft/lm_engine.py +++ b/arealite/engine/sft/lm_engine.py @@ -1,5 +1,3 @@ -from typing import Dict - import torch import torch.utils.data from tensordict import TensorDict @@ -44,9 +42,7 @@ class FSDPLMEngine(FSDPEngine): return self.lm_engine.evaluate_lm(data) -def compute_packed_sft_loss( - logits: torch.Tensor, input_: Dict[str, torch.Tensor] -) -> torch.Tensor: +def compute_packed_sft_loss(logits: torch.Tensor, input_: TensorDict) -> torch.Tensor: packed_input_ids: torch.Tensor = input_["input_ids"] cu_seqlens: torch.Tensor = input_["cu_seqlens"] loss_mask = input_["loss_mask"].bool() diff --git a/arealite/engine/sglang_remote.py b/arealite/engine/sglang_remote.py index c985fa6..4b9eb86 100644 --- a/arealite/engine/sglang_remote.py +++ b/arealite/engine/sglang_remote.py @@ -7,10 +7,12 @@ import traceback from concurrent.futures import ProcessPoolExecutor from datetime import datetime from queue import Empty, Full, Queue -from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List +from typing import TYPE_CHECKING, Any, Callable, Dict, List +import aiohttp import requests import torch.distributed as dist +import uvloop from tensordict import TensorDict from torchdata.stateful_dataloader import StatefulDataLoader @@ -23,8 +25,8 @@ from arealite.api.io_struct import ( RolloutStat, WeightUpdateMeta, ) -from arealite.utils.http import arequest_with_retry -from arealite.utils.padding import concat_padded_tensors +from arealite.utils.data import concat_padded_tensors +from arealite.utils.http import arequest_with_retry, get_default_connector from realhf.base import logging, name_resolve, names, pkg_version if TYPE_CHECKING: @@ -37,7 +39,7 @@ if pkg_version.is_available("sglang"): else: SGLANG_TOKEN_OUTPUT_IDENTIFIER = "token_ids" -ROLLOUT_POLL_WAIT_TIME = 0.1 +ROLLOUT_POLL_WAIT_TIME = 0.05 RID_CACHE_SIZE = 128 @@ -98,6 +100,7 @@ class RemoteSGLangEngine(InferenceEngine): def initialize(self, addr: str | None, ft_spec: FinetuneSpec = None): self.rollout_tasks: Dict[str, asyncio.Task] = {} + self.executor = ProcessPoolExecutor(max_workers=1) self.rollout_thread = threading.Thread(target=self._rollout_thread) self.rollout_thread.start() @@ -118,32 +121,39 @@ class RemoteSGLangEngine(InferenceEngine): def _rollout_thread(self): """Thread that runs the rollout loop.""" try: - asyncio.run(self._rollout_thread_async()) + uvloop.run(self._rollout_thread_async()) except Exception as e: traceback.print_exc() async def _rollout_thread_async(self): - pending_data = [] rollout_tasks = self.rollout_tasks rid = 0 + + # NOTE: session is not thread-safe, but we only submit requests in the sub-thread. 
+ self.session = aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout( + total=self.config.request_timeout, + sock_connect=self.config.request_timeout, + connect=self.config.request_timeout, + ), + read_bufsize=1024 * 1024 * 10, + connector=get_default_connector(), + ) + try: while not self.exiting.is_set(): - # Load next data from controller - while True: - try: - data, workflow = self.input_queue.get_nowait() - logger.debug(f"Get data from puller: {data}") - pending_data.append(data) - except Empty: - logger.debug(f"No data from puller stream.") - break - # Check capacity capacity = self.get_capacity() # Create new rollout task - while capacity > 0 and pending_data and not self.paused.is_set(): + while ( + capacity > 0 + and not self.paused.is_set() + and self.input_queue.qsize() > 0 + ): + data, workflow = self.input_queue.get_nowait() + logger.debug(f"Get data from puller: {data}") task = asyncio.create_task( - workflow.arun_episode(self, pending_data.pop(0)), name=str(rid) + workflow.arun_episode(self, data), name=str(rid) ) with self.lock: rollout_tasks[str(rid)] = task @@ -158,7 +168,6 @@ class RemoteSGLangEngine(InferenceEngine): ) capacity -= 1 rid += 1 - # Wait for rollout completion with self.lock: tasks = list(rollout_tasks.values()) @@ -169,11 +178,6 @@ class RemoteSGLangEngine(InferenceEngine): timeout=ROLLOUT_POLL_WAIT_TIME, return_when=asyncio.FIRST_COMPLETED, ) - if not done: - await asyncio.sleep(1) - else: - await asyncio.sleep(1) - # Collect done results for task in done: traj = await task @@ -199,6 +203,7 @@ class RemoteSGLangEngine(InferenceEngine): f"running: {self.rollout_stat.running}, " f"accepted: {self.rollout_stat.accepted}." ) + await asyncio.sleep(1) except Exception: traceback.print_exc() finally: @@ -213,10 +218,11 @@ class RemoteSGLangEngine(InferenceEngine): pass def choose_server(self) -> str: - if self.config.schedule_policy == "round_robin": - server = self.addresses[self.server_idx] - self.server_idx = (self.server_idx + 1) % len(self.addresses) - return server + with self.lock: + if self.config.schedule_policy == "round_robin": + server = self.addresses[self.server_idx] + self.server_idx = (self.server_idx + 1) % len(self.addresses) + return server raise NotImplementedError("Only round-robin scheduling is implemented.") async def agenerate(self, req: LLMRequest) -> LLMResponse: @@ -253,7 +259,6 @@ class RemoteSGLangEngine(InferenceEngine): accumulated_versions = [] # Deal with rollout interruption - completions = "" stop_reason = "length" if req.rid in self.rid_to_address: @@ -273,11 +278,12 @@ class RemoteSGLangEngine(InferenceEngine): ): # loop until the generation is complete result = await arequest_with_retry( - addr=self.choose_server(), + session=self.session, + addr=server_addr, endpoint="/generate", payload=payload, method="POST", - max_retries=3, + max_retries=self.config.request_retries, timeout=self.config.request_timeout, ) @@ -297,10 +303,7 @@ class RemoteSGLangEngine(InferenceEngine): stop_reason = finish_reason["type"] payload["input_ids"] += result[SGLANG_TOKEN_OUTPUT_IDENTIFIER] - sample_params["max_new_tokens"] = min( - sample_params["max_new_tokens"], - gconfig.max_new_tokens - len(output_tokens), - ) + sample_params["max_new_tokens"] -= len(output_tokens) latency = time.perf_counter() - start_time @@ -408,18 +411,24 @@ class RemoteSGLangEngine(InferenceEngine): def prepare_batch( self, - data_generator: Iterator, dataloader: StatefulDataLoader, workflow: "RolloutWorkflow", ): + if not hasattr(self, "data_generator"): + 
self.data_generator = iter(dataloader) assert dataloader.batch_size is not None while True: - if self.get_capacity() + dataloader.batch_size > 0: + # Submit at least two batches to allow maximum overlap + if ( + self.get_capacity() + dataloader.batch_size > 0 + and self.input_queue.qsize() + dataloader.batch_size + < self.input_queue.maxsize + ): try: - data = next(data_generator) + data = next(self.data_generator) except StopIteration: - data_generator = iter(dataloader) - data = next(data_generator) + self.data_generator = iter(dataloader) + data = next(self.data_generator) for item in data: self.submit(item, workflow=workflow) try: @@ -435,10 +444,12 @@ class RemoteSGLangEngine(InferenceEngine): async def aupdate_weights_from_disk( - addr, path: str, request_retries: int, request_timeout: float + session, addr, path: str, request_retries: int, request_timeout: float ): + tik = time.time() res = await arequest_with_retry( addr=addr, + session=session, endpoint="/update_weights_from_disk", payload=dict(model_path=str(path), allow_interrupt=True), method="POST", @@ -472,9 +483,19 @@ def update_weights_from_disk( logger.info( f"Begin update weights from {path}, responded in {(load_timestamp - save_timestamp):.2f}s" ) + session = aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout( + total=request_timeout, + sock_connect=request_timeout, + connect=request_timeout, + ), + read_bufsize=1024 * 1024 * 10, + connector=get_default_connector(), + ) jobs = [ aupdate_weights_from_disk( - addr, + session=session, + addr=addr, path=path, request_retries=request_retries, request_timeout=request_timeout, @@ -482,8 +503,9 @@ def update_weights_from_disk( for addr in addresses ] await asyncio.gather(*jobs) + await session.close() logger.info( f"Loading weights done in {(datetime.now().timestamp() - load_timestamp):.2f}s" ) - return asyncio.run(_fn()) + return uvloop.run(_fn()) diff --git a/arealite/tests/test_sglang_engine.py b/arealite/tests/test_sglang_engine.py index 8c8f458..71a6f2f 100644 --- a/arealite/tests/test_sglang_engine.py +++ b/arealite/tests/test_sglang_engine.py @@ -5,6 +5,7 @@ import time import uuid import pytest +import requests import torch from tensordict import TensorDict @@ -28,6 +29,17 @@ HOST = network.gethostip() RUN_SERVER_TIMEOUT = 180 +def check_server_health(base_url): + try: + response = requests.get( + f"{base_url}/metrics", + timeout=30, + ) + return response.status_code == 200 + except requests.exceptions.RequestException as e: + return False + + @pytest.fixture(scope="module") def sglang_server(): from realhf.base import seeding diff --git a/arealite/tests/test_sglang_local_engine.py b/arealite/tests/test_sglang_local_engine.py index 2cf4dce..faebc7a 100644 --- a/arealite/tests/test_sglang_local_engine.py +++ b/arealite/tests/test_sglang_local_engine.py @@ -67,7 +67,6 @@ async def test_local_sglang_generate(): gconfig=GenerationHyperparameters(max_new_tokens=16), ) resp = await engine.agenerate(req) - print(resp.completions) assert isinstance(resp, LLMResponse) assert resp.input_tokens == req.input_ids @@ -76,9 +75,6 @@ async def test_local_sglang_generate(): == len(resp.output_tokens) == len(resp.output_versions) ) - assert isinstance(resp.completions, str) - - time.sleep(5) engine.destroy() diff --git a/arealite/tests/test_engine.py b/arealite/tests/test_train_engine.py similarity index 92% rename from arealite/tests/test_engine.py rename to arealite/tests/test_train_engine.py index 6c5d07a..9c208b0 100644 --- a/arealite/tests/test_engine.py +++ 
b/arealite/tests/test_train_engine.py @@ -13,7 +13,6 @@ from transformers import AutoTokenizer from arealite.api.cli_args import MicroBatchSpec, OptimizerConfig, TrainEngineConfig from arealite.api.io_struct import FinetuneSpec, SaveLoadMeta -from arealite.engine.fsdp_engine import FSDPEngine VOCAB_SIZE = 100 MODEL_PATH = "/storage/testing/models/Qwen__Qwen3-1.7B/" @@ -53,10 +52,10 @@ def mock_input( def get_engine(engine_type: str, model_path: str): + from arealite.engine.autotp_engine import DeepSpeedAutoTPEngine from arealite.engine.fsdp_engine import FSDPEngine - from arealite.engine.hf_engine import HFEngine - engine_cls = {"hf": HFEngine, "fsdp": FSDPEngine}[engine_type] + engine_cls = {"auto_tp": DeepSpeedAutoTPEngine, "fsdp": FSDPEngine}[engine_type] engine_config = TrainEngineConfig( experiment_name=f"test-{engine_type}-engine", @@ -75,7 +74,7 @@ def mock_loss_fn(logits: torch.Tensor, input_data: Dict) -> torch.Tensor: return torch.mean(logits) -@pytest.fixture(scope="module", params=["fsdp", "hf"]) +@pytest.fixture(scope="module", params=["fsdp", "auto_tp"]) def engine(request): os.environ.update( { @@ -136,6 +135,12 @@ def test_train_batch(engine, mock_input): @torch.no_grad() def test_hf_save_load_weights(tmp_path_factory, engine, mock_input): + from arealite.engine.autotp_engine import DeepSpeedAutoTPEngine + + if isinstance(engine, DeepSpeedAutoTPEngine): + print("AutoTP engine does not support HF save/load for now.") + return + tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH) path = tmp_path_factory.mktemp("hf_engine_test") save_load_meta = SaveLoadMeta( diff --git a/arealite/utils/data.py b/arealite/utils/data.py index f6572d1..b5f3174 100644 --- a/arealite/utils/data.py +++ b/arealite/utils/data.py @@ -92,9 +92,7 @@ def unpad_input( seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad( - torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0) - ) + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) return ( rearrange(hidden_states, "b s ... 
-> (b s) ...")[indices], indices, @@ -116,16 +114,12 @@ def concat_padded_tensors( if not tensor_dicts: return TensorDict() - # Find max sequence length across all dictionaries - lens = [] - for tensor_dict in tensor_dicts: - for key, tensor in tensor_dict.items(): - if key != "attention_mask" and len(tensor.shape) == 2: - lens.append(tensor.shape[1]) - break - max_length = max(lens) - attn_mask = torch.arange(max_length).unsqueeze(0) < torch.tensor(lens).unsqueeze(1) + batch_sizes = [tuple(d.batch_size) for d in tensor_dicts] + new_batch_size = [sum(x[0] for x in batch_sizes), *batch_sizes[0][1:]] + # Find max sequence length across all dictionaries + assert all("attention_mask" in td for td in tensor_dicts) + max_length = max([x["attention_mask"].shape[1] for x in tensor_dicts]) result = {} # Process each key for key in tensor_dicts[0].keys(): @@ -154,9 +148,7 @@ def concat_padded_tensors( tensors_to_concat.append(tensor) result[key] = torch.cat(tensors_to_concat, dim=0) - if "attention_mask" not in result: - result["attention_mask"] = attn_mask - return TensorDict(result, batch_size=[len(lens)]) + return TensorDict(result, batch_size=new_batch_size) def to_device(data: Dict[str, torch.Tensor | Any], device) -> Dict[str, torch.Tensor]: @@ -231,7 +223,7 @@ def pack_tensor_dict(data: TensorDict): cu_seqlens = torch.cumsum(lens, dim=0) cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - total_length = cu_seqlens[-1].item() + total_length = int(cu_seqlens[-1].item()) # Pack tensors packed_data = {} for key, value in data.items(): @@ -246,7 +238,7 @@ def pack_tensor_dict(data: TensorDict): and value.shape[1] == seq_len ): packed_tensor = torch.empty( - total_length, *value.shape[2:], dtype=value.dtype, device=value.device + (total_length, *value.shape[2:]), dtype=value.dtype, device=value.device ) # Fill the packed tensor with values from the original tensor for i in range(bs): @@ -363,7 +355,7 @@ def split_padded_tensor_dict_into_mb_list( mbs=results, mb_spec=mb_spec, forward_indices=forward_indices, - backward_indices=backward_indices, + backward_indices=backward_indices.tolist(), group_lens=group_lens, ) diff --git a/arealite/utils/functional.py b/arealite/utils/functional.py index 9ce736c..3abafa8 100644 --- a/arealite/utils/functional.py +++ b/arealite/utils/functional.py @@ -121,19 +121,24 @@ def masked_normalization( def ppo_actor_loss_fn( logprobs: torch.Tensor, + proximal_logprobs: torch.Tensor, old_logprobs: torch.Tensor, advantages: torch.Tensor, eps_clip: float, loss_mask: torch.Tensor, c_clip: Optional[float] = None, - proximal_logprobs: Optional[torch.Tensor] = None, behav_imp_weight_cap: Optional[float] = None, ) -> Tuple[torch.Tensor, Dict]: - denorm_logprobs = ( - proximal_logprobs if proximal_logprobs is not None else old_logprobs - ) + """ + When decoupled loss is disabled: + 1. if recompute logp, both old_logprobs and proximal_logprobs are recomputed logp; + 2. if no recomputation, both old_logp and proximal_logprobs are produced by the inference backend. + + When decoupled loss is enabled, proximal_logprobs is the recomputed logp, + old_logprobs is produced by the inference engine. 
+ """ loss_mask_count = loss_mask.count_nonzero() or 1 - ratio = torch.where(loss_mask, torch.exp(logprobs - denorm_logprobs), 0) + ratio = torch.where(loss_mask, torch.exp(logprobs - proximal_logprobs), 0) clipped_ratio = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip) pg_loss1 = -advantages * ratio pg_loss2 = -advantages * clipped_ratio @@ -146,17 +151,16 @@ def ppo_actor_loss_fn( pg_loss = torch.min(pg_loss, pg_loss3) else: dual_clip_mask = torch.zeros_like(clip_mask) - if proximal_logprobs is not None: - behav_kl = proximal_logprobs - old_logprobs - behav_imp_weight = behav_kl.exp() - behav_mask = ( - (behav_imp_weight <= behav_imp_weight_cap).logical_and(loss_mask) - if behav_imp_weight_cap is not None - else loss_mask - ) - behav_kl = torch.where(behav_mask, behav_kl, 0.0) - behav_imp_weight = torch.where(behav_mask, behav_imp_weight, 0.0) - pg_loss = pg_loss * behav_imp_weight + behav_kl = proximal_logprobs - old_logprobs + behav_imp_weight = behav_kl.exp() + behav_mask = ( + (behav_imp_weight <= behav_imp_weight_cap).logical_and(loss_mask) + if behav_imp_weight_cap is not None + else loss_mask + ) + behav_kl = torch.where(behav_mask, behav_kl, 0.0) + behav_imp_weight = torch.where(behav_mask, behav_imp_weight, 0.0) + pg_loss = pg_loss * behav_imp_weight logging_loss = pg_loss.detach() pg_loss = torch.where(loss_mask, pg_loss, 0).sum() / loss_mask_count clip_mask.logical_and_(loss_mask) @@ -164,7 +168,7 @@ def ppo_actor_loss_fn( stat = dict( loss=logging_loss, importance_weight=ratio.detach(), - approx_kl=(logprobs - denorm_logprobs).detach(), + approx_kl=(logprobs - proximal_logprobs).detach(), clip_mask=clip_mask, dual_clip_mask=dual_clip_mask, ) diff --git a/arealite/utils/http.py b/arealite/utils/http.py index 29e4df4..a39a361 100644 --- a/arealite/utils/http.py +++ b/arealite/utils/http.py @@ -3,45 +3,74 @@ from typing import Any, Dict, Optional import aiohttp +from realhf.base import logging + DEFAULT_RETRIES = 1 DEFAULT_REQUEST_TIMEOUT = 3600 +logger = logging.getLogger(__file__) + + +def get_default_connector(): + return aiohttp.TCPConnector(limit=0, use_dns_cache=False, force_close=True) + + async def arequest_with_retry( addr: str, endpoint: str, payload: Optional[Dict[str, Any]] = None, + session: aiohttp.ClientSession | None = None, method: str = "POST", max_retries: Optional[int] = None, timeout: Optional[float] = None, retry_delay: float = 1.0, -) -> aiohttp.ClientResponse: + verbose=False, +) -> Dict: timeout = timeout or DEFAULT_REQUEST_TIMEOUT last_exception = None max_retries = max_retries or DEFAULT_RETRIES base_url = f"http://{addr}" url = f"{base_url}{endpoint}" + timeo = aiohttp.ClientTimeout( + total=timeout, + sock_connect=timeout, + connect=timeout, + ) + if session is None: + _session = aiohttp.ClientSession( + timeout=timeo, + read_bufsize=1024 * 1024 * 10, + connector=get_default_connector(), + ) + else: + _session = session + for attempt in range(max_retries): try: - async with aiohttp.ClientSession( - timeout=aiohttp.ClientTimeout( - total=timeout, - sock_connect=timeout, - ) - ) as session: - if method.upper() == "GET": - response = await session.get(url) - elif method.upper() == "POST": - response = await session.post(url, json=payload) - elif method.upper() == "PUT": - response = await session.put(url, json=payload) - elif method.upper() == "DELETE": - response = await session.delete(url) - else: - raise ValueError(f"Unsupported HTTP method: {method}") + if verbose: + logger.info("enter client session, start sending requests") + if method.upper() 
== "GET": + ctx = _session.get(url, timeout=timeo) + elif method.upper() == "POST": + ctx = _session.post(url, json=payload, timeout=timeo) + elif method.upper() == "PUT": + ctx = _session.put(url, json=payload, timeout=timeo) + elif method.upper() == "DELETE": + ctx = _session.delete(url, timeout=timeo) + else: + raise ValueError(f"Unsupported HTTP method: {method}") + async with ctx as response: + if verbose: + logger.info("http requests return") response.raise_for_status() - return await response.json() + res = await response.json() + if verbose: + logger.info("get http result") + if session is None: + await _session.close() + return res except ( aiohttp.ClientError, aiohttp.ClientResponseError, @@ -51,6 +80,10 @@ async def arequest_with_retry( if attempt < max_retries - 1: await asyncio.sleep(retry_delay) continue + if session is None: + await _session.close() raise RuntimeError( - f"Failed after {max_retries} retries each. " f"Last error: {last_exception}" + f"Failed after {max_retries} retries each. " + f"Payload: {payload}. Addr: {addr}. Endpoint: {endpoint}. " + f"Last error: {last_exception}" ) diff --git a/arealite/workflow/rlvr.py b/arealite/workflow/rlvr.py index 026f574..3ce55df 100644 --- a/arealite/workflow/rlvr.py +++ b/arealite/workflow/rlvr.py @@ -8,7 +8,7 @@ from transformers import PreTrainedTokenizerFast from arealite.api.cli_args import GenerationHyperparameters from arealite.api.io_struct import LLMRequest from arealite.api.workflow_api import RolloutWorkflow -from arealite.utils.padding import concat_padded_tensors +from arealite.utils.data import concat_padded_tensors class RLVRWorkflow(RolloutWorkflow): @@ -61,7 +61,7 @@ class RLVRWorkflow(RolloutWorkflow): versions=torch.tensor(versions).unsqueeze(0), attention_mask=torch.ones(len(seq), dtype=torch.bool).unsqueeze(0), # reward - rewards=torch.tensor([reward]), + rewards=torch.tensor([float(reward)]), ) results.append(TensorDict(res, batch_size=[1])) diff --git a/examples/arealite/gsm8k_grpo.py b/examples/arealite/gsm8k_grpo.py index 07434a0..1841343 100644 --- a/examples/arealite/gsm8k_grpo.py +++ b/examples/arealite/gsm8k_grpo.py @@ -152,11 +152,7 @@ def main_grpo(): with stats_tracker.record_timing("rollout"): if config.async_training: - batch = rollout.prepare_batch( - data_generator, - train_dataloader, - workflow=workflow, - ) + batch = rollout.prepare_batch(train_dataloader, workflow=workflow) else: try: data = next(data_generator) @@ -210,8 +206,9 @@ def main_grpo(): actor.upload_weights(meta) if dist.get_rank() == 0: future.result() - rollout.set_version(global_step + 1) dist.barrier() + torch.cuda.synchronize() + rollout.set_version(global_step + 1) with stats_tracker.record_timing("save"): saver.save(actor, epoch, step, global_step) diff --git a/realhf/base/name_resolve.py b/realhf/base/name_resolve.py index 9c6ec25..ef9879e 100644 --- a/realhf/base/name_resolve.py +++ b/realhf/base/name_resolve.py @@ -1360,7 +1360,9 @@ class RayNameResolveRepository: def make_repository(args: "NameResolveConfig"): if args.type == "nfs": - return NfsNameRecordRepository(args.nfs_record_root) + repo = NfsNameRecordRepository(args.nfs_record_root) + os.makedirs(repo.record_root, exist_ok=True) + return repo elif args.type == "etcd3": host, port = args.etcd3_addr.split(":") return Etcd3NameRecordRepository(host=host, port=int(port))
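Three short sketches of behavior changed earlier in this patch.

`ppo_actor_loss_fn` now takes `proximal_logprobs` as a required argument: the clipped ratio is computed against the proximal (recomputed) policy and the whole term is reweighted by the behavioral importance weight `exp(proximal_logprobs - old_logprobs)`, optionally capped. A simplified sketch under those semantics (dual clipping and the returned statistics are omitted; the standalone function name is illustrative):

```python
import torch


def decoupled_ppo_loss(
    logprobs: torch.Tensor,           # current policy log-probs
    proximal_logprobs: torch.Tensor,  # recomputed ("proximal") log-probs
    old_logprobs: torch.Tensor,       # log-probs from the inference backend
    advantages: torch.Tensor,
    loss_mask: torch.Tensor,
    eps_clip: float = 0.2,
    behav_imp_weight_cap: float | None = None,
) -> torch.Tensor:
    count = loss_mask.count_nonzero() or 1
    # Clipped surrogate against the proximal policy.
    ratio = torch.where(loss_mask, torch.exp(logprobs - proximal_logprobs), 0.0)
    pg_loss = -torch.min(
        advantages * ratio,
        advantages * torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip),
    )
    # Behavioral importance weight, optionally capped, applied to the whole term.
    behav_imp_weight = (proximal_logprobs - old_logprobs).exp()
    if behav_imp_weight_cap is not None:
        behav_mask = (behav_imp_weight <= behav_imp_weight_cap) & loss_mask
    else:
        behav_mask = loss_mask
    behav_imp_weight = torch.where(behav_mask, behav_imp_weight, 0.0)
    pg_loss = pg_loss * behav_imp_weight
    return torch.where(loss_mask, pg_loss, 0.0).sum() / count
```

When decoupled loss is disabled, `proximal_logprobs` equals `old_logprobs`, so the reweighting is a no-op and the objective reduces to standard PPO.

`arequest_with_retry` can now be handed a long-lived `aiohttp.ClientSession`; it only creates (and closes) a session of its own when none is passed. A hypothetical driver mirroring how `update_weights_from_disk` fans the call out to every server address with one shared session:

```python
import asyncio

import aiohttp
import uvloop

from arealite.utils.http import arequest_with_retry, get_default_connector


def broadcast_weight_update(addresses, model_path, timeout=3600, retries=3):
    # One shared session, one concurrent request per server, closed after the gather.
    async def _fn():
        session = aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(
                total=timeout, sock_connect=timeout, connect=timeout
            ),
            read_bufsize=1024 * 1024 * 10,
            connector=get_default_connector(),
        )
        try:
            return await asyncio.gather(*[
                arequest_with_retry(
                    session=session,
                    addr=addr,
                    endpoint="/update_weights_from_disk",
                    payload={"model_path": str(model_path), "allow_interrupt": True},
                    method="POST",
                    max_retries=retries,
                    timeout=timeout,
                )
                for addr in addresses
            ])
        finally:
            await session.close()

    return uvloop.run(_fn())
```

`concat_padded_tensors` now requires every `TensorDict` to carry an `attention_mask` and derives the output batch size from the inputs' `batch_size` fields rather than inferring lengths per key. Expected usage under the new contract (shapes assume shorter entries are still padded up to the longest attention mask):

```python
import torch
from tensordict import TensorDict

from arealite.utils.data import concat_padded_tensors

a = TensorDict(
    {
        "input_ids": torch.ones(1, 3, dtype=torch.long),
        "attention_mask": torch.ones(1, 3, dtype=torch.bool),
    },
    batch_size=[1],
)
b = TensorDict(
    {
        "input_ids": torch.ones(1, 5, dtype=torch.long),
        "attention_mask": torch.ones(1, 5, dtype=torch.bool),
    },
    batch_size=[1],
)
out = concat_padded_tensors([a, b])
assert tuple(out.batch_size) == (2,)   # batch dims are summed
print(out["input_ids"].shape)          # expected torch.Size([2, 5]): padded to max length
```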