[Fix] [lite] Merge from the internal repo to fix GRPO bugs and refactor the train engine (#181)

* PullRequest: 353 [Lite] Add gradient checkpointing to FSDPEngine

Merge branch mzy/add-gradient-ckpt of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/353

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* add gradient checkpointing
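For context, the flag added here routes through the standard HuggingFace gradient-checkpointing API (see the engine diff further below); a minimal, hedged sketch with an illustrative model path:

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-1.7B")  # illustrative path
# Same call the engine makes when config.gradient_checkpointing is set:
model.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs={"use_reentrant": False}
)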

* PullRequest: 354 [lite] GRPO pre-commit: minor changes in FSDP engine

Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/354

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .
* .
* .

* PullRequest: 355 [Lite] GRPO pre-commit 2: Refactor RemoteSGLangEngine thread and SGLang configuration

Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/355?tab=commit

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .
* .
* .
* .
* .
* fix
* .

* PullRequest: 357 [lite] GRPO pre-commit 3: Fix typos and experiment utilities

Merge branch fw/lite-fix2 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/357?tab=comment

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .
* .
* .
* .
* fix destroy process group
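The destroy fix concerns teardown ordering, which the refactored engines below implement as: destroy the engine's own parallelism sub-group first, then the global group only if this engine created it. A minimal sketch of that pattern (the function shape is illustrative):

import torch.distributed as dist

def destroy(parallelism_group: dist.ProcessGroup, own_global_group: bool) -> None:
    dist.destroy_process_group(parallelism_group)  # tear down the sub-group first
    if own_global_group:
        dist.destroy_process_group()               # then the default/global group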

* PullRequest: 358 [lite] Support GRPO training locally with the GSM8k dataset

Merge branch fw/lite-fix3 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/358

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .
* .
* .
* fix loss mask
* fix
* .

* PullRequest: 368 [lite] Refactor train engine after merging contributions from GitHub

Merge branch fw/lite-train-engine of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/368

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .

* PullRequest: 371 [lite] [fix] fix misc bugs in GRPO implementation

Merge branch fw/lite-fix0716 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/371

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .

---------

Co-authored-by: 晓雷 <meizhiyu.mzy@antgroup.com>
Wei Fu 2025-07-16 17:26:49 +08:00 committed by GitHub
parent 0283cfa124
commit 29e164a69d
19 changed files with 790 additions and 906 deletions

View File

@ -18,7 +18,7 @@ from arealite.utils.fs import get_user_tmp
class MicroBatchSpec:
"""Specification for splitting micro-batches during training."""
n_mbs: int = field(
n_mbs: Optional[int] = field(
default=1,
metadata={
"help": "Number of micro-batches (or minimum number if max_tokens_per_mb is set). Used when max_tokens_per_mb is None or as minimum count",
@ -161,7 +161,7 @@ class FSDPEngineConfig:
@dataclass
class HFEngineConfig:
class DeepSpeedAutoTPEngineConfig:
autotp_size: Optional[int] = field(
default=1,
metadata={"help": "DeepSpeed AutoTP size"},
@ -201,7 +201,88 @@ class TrainEngineConfig:
)
backend: str = ""
fsdp: FSDPEngineConfig = field(default_factory=FSDPEngineConfig)
hf: HFEngineConfig = field(default_factory=HFEngineConfig)
ds_auto_tp: DeepSpeedAutoTPEngineConfig = field(
default_factory=DeepSpeedAutoTPEngineConfig
)
@dataclass
class PPOActorConfig(TrainEngineConfig):
# Core PPO/GRPO Parameters
group_size: int = field(
default=1, metadata={"help": "Number of sequences in each group"}
)
group_adv_norm: bool = field(
default=False,
metadata={
"help": "Normalize advantages within each prompt group rather than globally"
},
)
ppo_n_minibatches: int = field(
default=4, metadata={"help": "Number of minibatches for each PPO update"}
)
eps_clip: float = field(
default=0.2, metadata={"help": "Clipping factor for policy ratio"}
)
c_clip: Optional[float] = field(
default=None,
metadata={
"help": "Dual clipping factor for policy ratio, must > 1.0. None disables dual clipping."
},
)
temperature: float = field(
default=1.0, metadata={"help": "Temperature during generation."}
)
# Reward
group_reward_norm: bool = field(
default=False,
metadata={
"help": "Normalize final reward of each sequence (GRPO-style) to reduce length bias"
},
)
reward_scaling: float = field(
default=1.0, metadata={"help": "Reward scaling factor"}
)
reward_bias: float = field(default=0.0, metadata={"help": "Reward bias"})
reward_clip: float = field(
default=20.0, metadata={"help": "Maximum absolute value for reward clipping"}
)
mask_no_eos_with_zero: bool = field(
default=False,
metadata={
"help": "Mask truncated generations (no EOS token) and exclude from training"
},
)
# Advantage Estimation
discount: float = field(
default=1.0, metadata={"help": "Discount factor for future rewards"}
)
gae_lambda: float = field(
default=1.0, metadata={"help": "Lambda parameter for GAE"}
)
adv_norm: bool = field(
default=True, metadata={"help": "Enable advantage normalization"}
)
# KL Control
kl_ctl: float = field(default=0.1, metadata={"help": "KL divergence coefficient"})
# Asynchronous RL
recompute_logprob: bool = field(
default=False,
metadata={"help": "Recompute logp and replace the logp returned by inference."},
)
use_decoupled_loss: bool = field(
default=False,
metadata={"help": "Use the decoupled loss. recompute_logprob must be True."},
)
behav_imp_weight_cap: Optional[float] = field(
default=None,
metadata={
"help": "We filter out the tokens where behav_imp_weight exceeds behav_imp_weight_cap when computing the loss, must be > 1.0, use_decoupled_loss must be true"
},
)
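# A hedged configuration sketch tying the asynchronous-RL fields together
# (values illustrative; required base-class fields such as experiment_name,
# trial_name and path are omitted):
#   PPOActorConfig(
#       group_size=8,              # sequences sampled per prompt group
#       group_reward_norm=True,    # GRPO-style per-group reward normalization
#       recompute_logprob=True,    # prerequisite for the decoupled loss
#       use_decoupled_loss=True,
#       behav_imp_weight_cap=5.0,  # must be > 1.0, see the field docs above
#   )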
@dataclass

View File

@ -25,10 +25,10 @@ class Scheduling:
cpu: int
gpu: int
mem: int
nodelist: str = None
exclude: str = None
partition: str = None
container_image: str = None
nodelist: Optional[str] = None
exclude: Optional[str] = None
partition: Optional[str] = None
container_image: Optional[str] = None
env_vars: Dict[str, str] = field(default_factory=dict)
# time utils from "https://slurm.schedmd.com/sbatch.html"
time_limit: Optional[str] = None # see "--time" option for format
@ -105,7 +105,7 @@ class TrainEngine(abc.ABC):
def forward(
self,
input_: TensorDict,
output_seqlens: List[List[int]] | None = None,
output_seqlens: List[int] | None = None,
post_hook: Callable[[torch.Tensor, TensorDict], Any] | None = None,
aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
) -> Any | None:

View File

@ -71,7 +71,7 @@ class AllocationType(enum.Enum):
@dataclass
class AllocationMode:
type_: AllocationType
parallel_strat: None | Dict[str, Dict[str, int]]
parallel_strat: Dict[str, Dict[str, int]]
@property
def gen_tp_size(self) -> int:
@ -115,7 +115,7 @@ class AllocationMode:
raise NotImplementedError(f"Failed to parse allocation: {allocation_mode}")
@staticmethod
def extract_3d_alloc(allocation_mode: str) -> Dict | None:
def extract_parallelism_strategy(allocation_mode: str) -> Dict:
for x, y, z in itertools.permutations(["d", "t", "p"]):
pattern = rf"{x}(\d+){y}(\d+){z}(\d+)"
m = re.match(pattern, allocation_mode)
@ -130,29 +130,28 @@ class AllocationMode:
z: c,
}
}
raise ValueError(
f"Unknown how to resolve parallelism strategy: {allocation_mode}"
)
@staticmethod
def extract_decoupled_alloc(allocation_mode: str) -> Dict | None:
def extract_decoupled_alloc(allocation_mode: str) -> Dict:
pattern = re.compile(
r"(?:(?:vllm|sglang)\.(.+?)\+(.+))|(?:(.+?)\+(?:vllm|sglang)\.(.+))"
)
m = pattern.match(allocation_mode)
if not m:
return
raise ValueError(
f"Unknown how to resolve decoupled allocation: {allocation_mode}"
)
if m.group(1):
gen_alloc = m.group(1)
other_alloc = m.group(2)
else:
gen_alloc = m.group(4)
other_alloc = m.group(3)
gen_alloc = AllocationMode.extract_3d_alloc(gen_alloc)
if not gen_alloc:
return
other_alloc = AllocationMode.extract_3d_alloc(
other_alloc
) or AllocationMode.extract_key_value_alloc(other_alloc)
if not other_alloc:
return
gen_alloc = AllocationMode.extract_parallelism_strategy(gen_alloc)
other_alloc = AllocationMode.extract_parallelism_strategy(other_alloc)
other_alloc.update({"gen": gen_alloc["*"]})
return other_alloc
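# Worked example of the renamed helpers above (allocation string illustrative):
# extract_decoupled_alloc("sglang.d4t2p1+d2t1p1") now returns
#   {"*": {"d": 2, "t": 1, "p": 1},      # training parallelism
#    "gen": {"d": 4, "t": 2, "p": 1}}    # generation (sglang) parallelism
# while an unparsable string raises ValueError instead of silently returning None.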
@ -171,7 +170,7 @@ class SaveLoadMeta:
path: str
weight_format: str
with_optim: bool
tokenizer: PreTrainedTokenizerFast | None
tokenizer: Optional[PreTrainedTokenizerFast]
base_model_path: str | None
naive_distributed: bool = False

View File

@ -0,0 +1,128 @@
import os
import torch
import torch.distributed as dist
from safetensors.torch import save_file
from arealite.api.cli_args import TrainEngineConfig
from arealite.api.engine_api import FinetuneSpec, SaveLoadMeta, WeightUpdateMeta
from arealite.engine.base_hf_engine import BaseHFEngine
from arealite.utils.save_load import (
get_state_dict_from_repo_id_or_path,
is_existing_local_path,
)
from realhf.base import constants, logging
logger = logging.getLogger("DeepSpeedAutoTPEngine")
class DeepSpeedAutoTPEngine(BaseHFEngine):
def __init__(self, config: TrainEngineConfig):
super().__init__(config)
def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None):
"""Initialize distributed communication and model."""
assert (
addr is None
), "DeepSpeedAutoTPEngine does not support remote initialization."
import deepspeed
self.create_process_group()
world_size = int(os.environ["WORLD_SIZE"])
deepspeed.init_distributed(
dist_backend="nccl",
world_size=world_size,
timeout=constants.NCCL_DEFAULT_TIMEOUT,
)
self.create_device_model()
# NOTE: the device context manager does not work here.
self.model = deepspeed.tp_model_init(
self.model,
tp_size=self.config.ds_auto_tp.autotp_size,
dtype=getattr(torch, self.config.dtype),
).to(self.device)
self.create_optimizer(ft_spec)
self.initialized = True
def _check_autotp(self):
tp_size = self.config.ds_auto_tp.autotp_size
config = self.model_config
num_attention_heads = config.num_attention_heads
num_key_value_heads = config.num_key_value_heads
hidden_size = config.hidden_size
intermediate_size = config.intermediate_size
return (
num_attention_heads % tp_size == 0
and num_key_value_heads % tp_size == 0
and hidden_size % tp_size == 0
and intermediate_size % tp_size == 0
)
def save(self, meta: SaveLoadMeta):
if meta.weight_format != "naive_distributed":
raise ValueError(f"Unknown weight format {meta.weight_format}. ")
if self.model is None:
raise RuntimeError("Model not initialized")
rank = dist.get_rank()
world_size = dist.get_world_size()
if rank == 0:
os.makedirs(meta.path, exist_ok=True)
self.model_config.save_pretrained(
meta.path,
)
if meta.tokenizer is not None:
meta.tokenizer.save_pretrained(
meta.path,
)
state_dict = self.model.state_dict()
if hasattr(self.model, "module"):
state_dict = {
k.replace("module.", "", 1) if k.startswith("module.") else k: v.cpu()
for k, v in state_dict.items()
}
else:
state_dict = {k: v.cpu() for k, v in state_dict.items()}
# Only supports saving each rank's model partition as a separate file
gathered_state_dicts = None
if rank == 0:
gathered_state_dicts = [None for _ in range(world_size)]
dist.gather_object(
obj=state_dict, object_gather_list=gathered_state_dicts, dst=0
)
if rank == 0:
for i, state_dict in enumerate(gathered_state_dicts):
save_file(state_dict, f"{meta.path}/rank_{i:02d}_model.safetensors")
if meta.with_optim:
self.save_optimizer_state(meta.path)
def load(self, meta: SaveLoadMeta):
if meta.weight_format != "naive_distributed":
raise ValueError(f"Unknown weight format {meta.weight_format}. ")
rank = dist.get_rank()
# Only supports loading full model parameters from HuggingFace,
# or loading a model partition from a local path
if rank == 0 or is_existing_local_path(meta.path):
path = f"{meta.path}/rank_{rank:02d}_model.safetensors"
full_state = get_state_dict_from_repo_id_or_path(meta.path)
if hasattr(self.model, "module"):
full_state = {
f"module.{k}" if not k.startswith("module.") else k: v
for k, v in full_state.items()
}
self.model.load_state_dict(
full_state, strict=not self.model_config.tie_word_embeddings
)
if self.model_config.tie_word_embeddings:
self.model.tie_weights()
if meta.with_optim:
self.load_optimizer_state(meta.path)
def upload_weights(self, meta: WeightUpdateMeta):
raise ValueError(f"update weight not implemented {meta.type}")

View File

@ -0,0 +1,360 @@
import gc
import os
import time
from typing import Any, Callable, Dict, List
import torch
import torch.distributed as dist
from tensordict import TensorDict
from transformers import (
AutoConfig,
AutoModelForCausalLM,
PretrainedConfig,
PreTrainedTokenizerFast,
get_constant_schedule_with_warmup,
get_linear_schedule_with_warmup,
)
from arealite.api.cli_args import TrainEngineConfig
from arealite.api.engine_api import FinetuneSpec, TrainEngine
from arealite.utils.data import (
MicroBatchList,
amend_position_ids,
pack_tensor_dict,
pad_and_stack_tensors_along_first_dim,
pad_mb_list,
reorder_list,
split_padded_tensor_dict_into_mb_list,
unpack_sequence,
unsqueeze_mb_list,
)
from arealite.utils.fsdp import get_cosine_schedule_with_warmup
from arealite.utils.model import disable_dropout_in_model
from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import constants, logging
logger = logging.getLogger("Base HF Engine")
class BaseHFEngine(TrainEngine):
def __init__(self, config: TrainEngineConfig):
self.config = config
self.optimizer_config = config.optimizer
self.model: torch.nn.Module
self.optimizer: torch.optim.Optimizer
self.tokenizer: PreTrainedTokenizerFast
# huggingface model config
self.model_config: PretrainedConfig
# initialization
self.initialized = False
self.own_global_group = False
self._parallelism_group: dist.ProcessGroup
self.weight_update_group_initialized = False
self.world_size = int(os.environ["WORLD_SIZE"])
def train(self, mode: bool = True):
assert self.model is not None
self.model.train(mode=mode)
return self
@property
def parallelism_group(self) -> dist.ProcessGroup:
assert self.initialized
return self._parallelism_group
def create_process_group(self):
if not dist.is_initialized():
# TODO: Handle the case where WORLD_SIZE and RANK are not set by the launcher
dist.init_process_group(
backend="nccl",
timeout=constants.NCCL_DEFAULT_TIMEOUT,
device_id=torch.device(int(os.environ["LOCAL_RANK"])),
)
self.own_global_group = True
self._parallelism_group = dist.new_group()
def create_device_model(self):
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
self.device = torch.device(int(os.environ["LOCAL_RANK"]))
dtype = getattr(torch, self.config.dtype)
self.model_config = AutoConfig.from_pretrained(
pretrained_model_name_or_path=self.config.path,
trust_remote_code=True,
)
self.tokenizer = load_hf_tokenizer(self.config.path)
tik = time.perf_counter()
with torch.device("cuda"):
if self.config.init_from_scratch:
# initialize scratch model from config
# NOTE: VLMs cannot load a state dict directly into this randomly
# initialized model, so when not initializing from scratch we call
# from_pretrained instead of loading weights into the random model.
model = AutoModelForCausalLM.from_config(
self.model_config,
torch_dtype=dtype,
attn_implementation=self.config.attn_impl,
)
else:
model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=self.config.path,
trust_remote_code=True,
torch_dtype=dtype,
attn_implementation=self.config.attn_impl,
)
if self.config.disable_dropout:
disable_dropout_in_model(model)
if self.config.gradient_checkpointing:
model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False}
)
logger.info(f"Model creation and loading time: {time.perf_counter() - tik}")
self.model = model
def create_optimizer(self, ft_spec: FinetuneSpec):
if self.optimizer_config is None:
return
assert self.model is not None
# Set up optimizer
tik = time.perf_counter()
assert (
self.optimizer_config.type == "adam"
), "Only AdamW optimizer is supported in this engine."
lr = self.optimizer_config.lr
weight_decay = self.optimizer_config.weight_decay
beta1 = self.optimizer_config.beta1
beta2 = self.optimizer_config.beta2
eps = self.optimizer_config.eps
self.optimizer = torch.optim.AdamW(
self.model.parameters(),
lr=lr,
weight_decay=weight_decay,
betas=(beta1, beta2),
eps=eps,
)
total_train_steps = ft_spec.total_train_steps
num_warmup_steps = int(
self.optimizer_config.warmup_steps_proportion * total_train_steps
)
if self.optimizer_config.lr_scheduler_type == "cosine":
self.lr_scheduler = get_cosine_schedule_with_warmup(
self.optimizer,
num_warmup_steps,
total_train_steps,
min_lr_ratio=self.optimizer_config.min_lr_ratio,
)
elif self.optimizer_config.lr_scheduler_type == "linear":
self.lr_scheduler = get_linear_schedule_with_warmup(
self.optimizer,
num_warmup_steps,
total_train_steps,
)
elif self.optimizer_config.lr_scheduler_type == "constant":
self.lr_scheduler = get_constant_schedule_with_warmup(
self.optimizer,
num_warmup_steps,
)
else:
raise ValueError(
f"Unknown lr scheduler type {self.optimizer_config.lr_scheduler_type}"
)
logger.info(f"Create optimizer time: {time.perf_counter() - tik}")
def destroy(self):
"""Destroy the engine and release GPU memory."""
del self.optimizer
del self.model
gc.collect()
torch.cuda.empty_cache()
gc.collect()
dist.destroy_process_group(self.parallelism_group)
if self.own_global_group:
dist.destroy_process_group()
self.initialized = False
def save_optimizer_state(self, path: str):
# Save FSDP sharded state dict on each rank
assert self.optimizer is not None
assert dist.is_initialized()
rank = dist.get_rank()
shard_path = os.path.join(
path, f"optim_world_size_{self.world_size}_rank_{rank}.pt"
)
state_dict = self.optimizer.state_dict()
torch.save(state_dict, shard_path)
dist.barrier()
def load_optimizer_state(self, path: str):
# Load FSDP sharded state dict
assert self.optimizer is not None
assert dist.is_initialized()
rank = dist.get_rank()
shard_path = os.path.join(
path, f"optim_world_size_{self.world_size}_rank_{rank}.pt"
)
optimizer_state_dict = torch.load(shard_path, weights_only=False)
self.optimizer.load_state_dict(optimizer_state_dict)
dist.barrier()
def step_lr_scheduler(self):
assert self.lr_scheduler is not None
self.lr_scheduler.step()
def prepare_mb_list(self, input_: TensorDict) -> MicroBatchList:
assert "attention_mask" in input_ and "input_ids" in input_
if isinstance(input_, dict):
input_ = TensorDict(input_, batch_size=[input_["input_ids"].shape[0]])
input_ = amend_position_ids(input_)
mb_list = split_padded_tensor_dict_into_mb_list(input_, self.config.mb_spec)
logger.info(
f"Microbatch #tokens (rank {dist.get_rank()}): {mb_list.group_lens}"
)
mb_list.mbs = [pack_tensor_dict(mb) for mb in mb_list.mbs]
mb_list = pad_mb_list(mb_list, pad_value=0.0)
# NOTE: We unsqueeze here because huggingface transformer models require
# packed input to be of shape [1, total_seqlen].
mb_list = unsqueeze_mb_list(mb_list)
# FIXME: the resulting max_seqlen is a tensor rather than an integer
for mb in mb_list.mbs:
mb["max_seqlen"] = int(mb["max_seqlen"])
mb["use_cache"] = False
for mb in mb_list.padded_mbs:
mb["max_seqlen"] = int(mb["max_seqlen"])
mb["use_cache"] = False
return mb_list
def train_batch(
self,
input_: TensorDict,
loss_fn: Callable[[torch.Tensor, TensorDict], torch.Tensor],
loss_weight_fn: Callable[[TensorDict], float],
) -> Dict[str, float]:
"""Train on a batch using gradient accumulation."""
input_ = input_.to(self.device)
assert self.optimizer is not None
assert self.optimizer_config is not None
assert self.lr_scheduler is not None
self.optimizer.zero_grad()
mb_list = self.prepare_mb_list(input_)
total_loss_weight = torch.tensor(
sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32
)
assert total_loss_weight != 0
dist.all_reduce(total_loss_weight)
# Process microbatches with gradient accumulation
for i, (pad_length, padded_mb_input, mb_input) in enumerate(
zip(mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs)
):
outputs = self.model(**padded_mb_input)
logits = outputs.logits.squeeze(0)
logits = logits[:-pad_length] if pad_length > 0 else logits
loss = loss_fn(logits, mb_input)
loss_scale = loss_weight_fn(mb_input) / total_loss_weight
# Scale loss for accumulation
# Revert gradient averaging across dp ranks
loss_scale *= self.world_size
loss *= loss_scale
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
self.optimizer_config.gradient_clipping,
norm_type=2.0,
error_if_nonfinite=False,
foreach=None,
)
if not torch.isfinite(grad_norm):
self.optimizer.zero_grad()
update_successful = False
else:
self.optimizer.step()
update_successful = True
current_lr = self.lr_scheduler.get_last_lr()[0]
return dict(
update_successful=float(update_successful),
grad_norm=float(grad_norm) if grad_norm is not None else float("nan"),
lr=current_lr,
)
@torch.no_grad()
def eval_batch(
self,
input_: TensorDict,
loss_fn: Callable[[torch.Tensor, TensorDict], torch.Tensor],
loss_weight_fn: Callable[[TensorDict], float],
) -> torch.Tensor | None:
"""Evaluate on a batch."""
input_ = input_.to(self.device)
mb_list = self.prepare_mb_list(input_)
total_loss_weight = torch.tensor(
sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32
)
assert total_loss_weight != 0
total_loss = 0.0
total_weight = 0.0
for pad_length, padded_mb_input, mb_input in zip(
mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
):
outputs = self.model(**padded_mb_input)
logits = outputs.logits.squeeze(0)
logits = logits[:-pad_length] if pad_length > 0 else logits
loss = loss_fn(logits, mb_input)
# Simple weight calculation (could be improved)
loss_scale = loss_weight_fn(mb_input) / total_loss_weight
total_loss += loss.item() * loss_scale
total_weight += loss_scale
return torch.tensor(total_loss / total_weight)
@torch.no_grad()
def forward(
self,
input_: TensorDict,
output_seqlens: List[int] | None = None,
post_hook: Callable[[torch.Tensor, TensorDict], Any] | None = None,
aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
) -> Any | None:
"""Forward pass with optional post-processing."""
input_ = input_.to(self.device)
cu_seqlens = pack_tensor_dict(input_)["cu_seqlens"]
mb_list = self.prepare_mb_list(input_)
if output_seqlens is None:
output_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist()
results = []
for pad_length, padded_mb_input, mb_input in zip(
mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
):
outputs = self.model(**padded_mb_input)
logits = outputs.logits.squeeze(0)
logits = logits[:-pad_length] if pad_length > 0 else logits
if post_hook:
result = post_hook(logits, mb_input)
results.append(result)
else:
results.append(logits)
res = aggregate_fn(results)
output_seqlens = [output_seqlens[i] for i in mb_list.forward_indices]
unpacked = unpack_sequence(res, lens=output_seqlens, dim=0)
reordered = reorder_list(unpacked, mb_list.backward_indices)
return pad_and_stack_tensors_along_first_dim(reordered)
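To make the loss-weight machinery above concrete: train_batch expects the caller to pass both a per-microbatch loss and a per-microbatch weight, and scales each microbatch's loss by its weight divided by the weight sum across all data-parallel ranks. A hedged driver sketch (the loss is a placeholder mirroring the mock loss in the tests later in this diff; token-count weighting is an assumption):

import torch
from tensordict import TensorDict

def mock_loss_fn(logits: torch.Tensor, mb: TensorDict) -> torch.Tensor:
    # Placeholder; a real caller computes e.g. a masked cross-entropy over the packed tokens.
    return logits.float().mean()

def token_count_weight(mb: TensorDict) -> float:
    # Weight each microbatch by its token count so the loss scales sum to 1 across dp ranks.
    return float(mb["input_ids"].numel())

# stats = engine.train_batch(batch, loss_fn=mock_loss_fn, loss_weight_fn=token_count_weight)
# stats -> {"update_successful": 1.0, "grad_norm": ..., "lr": ...}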

View File

@ -1,42 +1,20 @@
import gc
import os
import time
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional
from typing import Callable, Dict, Optional
import torch
import torch.distributed as dist
import transformers
from tensordict import TensorDict
from torch.distributed.checkpoint.state_dict import (
StateDictOptions,
get_model_state_dict,
)
from transformers import (
AutoConfig,
AutoModelForCausalLM,
get_constant_schedule_with_warmup,
get_linear_schedule_with_warmup,
)
from transformers import PreTrainedTokenizerFast
from arealite.api.cli_args import TrainEngineConfig
from arealite.api.engine_api import (
FinetuneSpec,
SaveLoadMeta,
TrainEngine,
WeightUpdateMeta,
)
from arealite.utils.data import (
MicroBatchList,
amend_position_ids,
pack_tensor_dict,
pad_and_stack_tensors_along_first_dim,
pad_mb_list,
reorder_list,
split_padded_tensor_dict_into_mb_list,
unpack_sequence,
unsqueeze_mb_list,
)
from arealite.api.engine_api import FinetuneSpec, SaveLoadMeta, WeightUpdateMeta
from arealite.engine.base_hf_engine import BaseHFEngine
from arealite.utils.fsdp import (
CPUOffloadPolicy,
MixedPrecisionPolicy,
@ -44,108 +22,35 @@ from arealite.utils.fsdp import (
create_fsdp_device_mesh,
fsdp2_clip_grad_norm_,
fsdp2_load_full_state_dict,
get_cosine_schedule_with_warmup,
)
from arealite.utils.model import disable_dropout_in_model
from arealite.utils.save_load import get_state_dict_from_repo_id_or_path
from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import constants, logging, name_resolve, names, pkg_version
from realhf.base import logging, name_resolve, names, pkg_version
logger = logging.getLogger("FSDPEngine")
class FSDPEngine(TrainEngine):
class FSDPEngine(BaseHFEngine):
def __init__(self, config: TrainEngineConfig):
self.config = config
self.optimizer_config = config.optimizer
self.model = None
self.optimizer = None
self.tokenizer = None
# huggingface model config
self.model_config = None
super().__init__(config)
# FSDP options
self.mixed_precision_policy = None
self.device_mesh = None
self.cpu_offload = None
# initialization
self.initialized = False
self.own_global_group = False
self._parallelism_group = None
self.weight_update_group_initialized = False
# TODO: Handle the case when WORLD_SIZE is not set in launcher
self.world_size = int(os.environ["WORLD_SIZE"])
def train(self, mode: bool = True):
assert self.model is not None
self.model.train(mode=mode)
return self
@property
def parallelism_group(self) -> dist.ProcessGroup:
assert self.initialized
return self._parallelism_group
def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None):
# Initialize distributed environments and load model.
assert addr is None, "FSDPEngine does not support remote initialization."
assert pkg_version.is_version_greater_or_equal(
"torch", "2.4.0"
), f"arealite only supports FSDP2, which requires torch>=2.4.0"
"""Initialize distributed communication and model."""
if not dist.is_initialized():
# TODO: Handle the condition when WORLD_SIZE and RANK is not set in launcher
dist.init_process_group(
backend="nccl",
timeout=constants.NCCL_DEFAULT_TIMEOUT,
device_id=torch.device(int(os.environ["LOCAL_RANK"])),
)
self.own_global_group = True
self._parallelism_group = dist.new_group()
# TODO: Handle the condition when LOCAL_RANK is not set in launcher
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
self.device = torch.device(int(os.environ["LOCAL_RANK"]))
dtype = getattr(torch, self.config.dtype)
self.model_config = AutoConfig.from_pretrained(
pretrained_model_name_or_path=self.config.path,
trust_remote_code=True,
)
self.tokenizer = load_hf_tokenizer(self.config.path)
tik = time.perf_counter()
with torch.device("cuda"):
if self.config.init_from_scratch:
# initialize scratch model from config
# NOTE: VLM cannot directly load state dict using this
# random initialized model, so otherwise we call
# from_pretrained rather than loading weights into this random model.
model = AutoModelForCausalLM.from_config(
self.model_config,
torch_dtype=dtype,
attn_implementation=self.config.attn_impl,
)
else:
model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=self.config.path,
trust_remote_code=True,
torch_dtype=dtype,
attn_implementation=self.config.attn_impl,
)
if self.config.disable_dropout:
disable_dropout_in_model(model)
if self.config.gradient_checkpointing:
model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False}
)
logger.info(f"Model creation and loading time: {time.perf_counter() - tik}")
self.create_process_group()
self.create_device_model()
# Wrap with FSDP2
# Simple auto wrap policy
self.mixed_precision_policy = MixedPrecisionPolicy(
param_dtype=dtype,
param_dtype=getattr(torch, self.config.dtype),
reduce_dtype=torch.float32,
cast_forward_inputs=True,
)
@ -154,82 +59,19 @@ class FSDPEngine(TrainEngine):
self.cpu_offload = (
CPUOffloadPolicy() if self.config.fsdp.offload_params else None
)
fsdp_kwargs = {
"mesh": self.device_mesh,
"mp_policy": self.mixed_precision_policy,
"offload_policy": self.cpu_offload,
"reshard_after_forward": True,
}
# Wrap with FSDP2
tik = time.perf_counter()
apply_fsdp2(model, fsdp_kwargs, self.config.fsdp.wrap_policy)
apply_fsdp2(self.model, fsdp_kwargs, self.config.fsdp.wrap_policy)
logger.info(f"Applying FSDP2 time: {time.perf_counter() - tik}")
self.model = model
# Set up optimizer
if self.optimizer_config is not None:
tik = time.perf_counter()
assert (
self.optimizer_config.type == "adam"
), "Only AdamW optimizer is supported in this engine."
lr = self.optimizer_config.lr
weight_decay = self.optimizer_config.weight_decay
beta1 = self.optimizer_config.beta1
beta2 = self.optimizer_config.beta2
eps = self.optimizer_config.eps
self.optimizer = torch.optim.AdamW(
self.model.parameters(),
lr=lr,
weight_decay=weight_decay,
betas=(beta1, beta2),
eps=eps,
)
total_train_steps = ft_spec.total_train_steps
num_warmup_steps = int(
self.optimizer_config.warmup_steps_proportion * total_train_steps
)
if self.optimizer_config.lr_scheduler_type == "cosine":
self.lr_scheduler = get_cosine_schedule_with_warmup(
self.optimizer,
num_warmup_steps,
total_train_steps,
min_lr_ratio=self.optimizer_config.min_lr_ratio,
)
elif self.optimizer_config.lr_scheduler_type == "linear":
self.lr_scheduler = get_linear_schedule_with_warmup(
self.optimizer,
num_warmup_steps,
total_train_steps,
)
elif self.optimizer_config.lr_scheduler_type == "constant":
self.lr_scheduler = get_constant_schedule_with_warmup(
self.optimizer,
num_warmup_steps,
)
else:
raise ValueError(
f"Unknown lr scheduler type {self.optimizer_config.lr_scheduler_type}"
)
logger.info(f"Create optimizer time: {time.perf_counter() - tik}")
self.create_optimizer(ft_spec)
self.initialized = True
def destroy(self):
"""Destroy the engine and release GPU memory."""
self.model = None
self.optimizer = None
gc.collect()
torch.cuda.empty_cache()
gc.collect()
dist.destroy_process_group(self.parallelism_group)
if self.own_global_group:
dist.destroy_process_group()
self.initialized = False
def save(self, meta: SaveLoadMeta):
if meta.weight_format == "hf":
self._save_model_to_hf(meta.path, meta.tokenizer)
@ -240,7 +82,7 @@ class FSDPEngine(TrainEngine):
raise ValueError(f"Unknown weight format {meta.weight_format}. ")
if meta.with_optim:
self._save_optimizer_state(meta.path)
self.save_optimizer_state(meta.path)
def load(self, meta: SaveLoadMeta):
if meta.weight_format == "hf":
@ -252,34 +94,10 @@ class FSDPEngine(TrainEngine):
raise ValueError(f"Unknown weight format {meta.weight_format}. ")
if meta.with_optim:
self._load_optimizer_state(meta.path)
def _save_optimizer_state(self, path: str):
# Save FSDP sharded state dict on each rank
assert self.optimizer is not None
assert dist.is_initialized()
rank = dist.get_rank()
shard_path = os.path.join(
path, f"optim_world_size_{self.world_size}_rank_{rank}.pt"
)
state_dict = self.optimizer.state_dict()
torch.save(state_dict, shard_path)
dist.barrier()
def _load_optimizer_state(self, path: str):
# Load FSDP sharded state dict
assert self.optimizer is not None
assert dist.is_initialized()
rank = dist.get_rank()
shard_path = os.path.join(
path, f"optim_world_size_{self.world_size}_rank_{rank}.pt"
)
optimizer_state_dict = torch.load(shard_path, weights_only=False)
self.optimizer.load_state_dict(optimizer_state_dict)
dist.barrier()
self.load_optimizer_state(meta.path)
def _save_model_to_hf(
self, path: str, tokenizer: Optional[transformers.PreTrainedTokenizerFast]
self, path: str, tokenizer: Optional[PreTrainedTokenizerFast]
):
"""Save model in HuggingFace format."""
if self.model is None:
@ -345,35 +163,11 @@ class FSDPEngine(TrainEngine):
"Distributed weight update is not implemented for FSDPEngine yet. "
)
def step_lr_scheduler(self):
assert self.lr_scheduler is not None
self.lr_scheduler.step()
def _prepare_mb_list(self, input_: TensorDict) -> MicroBatchList:
assert "attention_mask" in input_ and "input_ids" in input_
if isinstance(input_, dict):
input_ = TensorDict(input_, batch_size=[input_["input_ids"].shape[0]])
input_ = amend_position_ids(input_)
mb_list = split_padded_tensor_dict_into_mb_list(input_, self.config.mb_spec)
logger.info(
f"Microbatch #tokens (rank {dist.get_rank()}): {mb_list.group_lens}"
)
mb_list.mbs = [pack_tensor_dict(mb) for mb in mb_list.mbs]
mb_list = pad_mb_list(mb_list, pad_value=0.0)
# NOTE: We unsqueeze here because huggingface transformer models requires
# packed input to be of shape [1, total_seqlen].
mb_list = unsqueeze_mb_list(mb_list)
# FIXME: the resulting max_seqlen is a tensor rather than an integer
for mb in mb_list.mbs:
mb["max_seqlen"] = int(mb["max_seqlen"])
mb["use_cache"] = False
return mb_list
def train_batch(
self,
input_: TensorDict,
loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
loss_weight_fn: Callable[[Dict], float],
loss_fn: Callable[[torch.Tensor, TensorDict], torch.Tensor],
loss_weight_fn: Callable[[TensorDict], float],
) -> Dict[str, float]:
"""Train on a batch using gradient accumulation."""
input_ = input_.to(self.device)
@ -382,7 +176,7 @@ class FSDPEngine(TrainEngine):
assert self.lr_scheduler is not None
self.optimizer.zero_grad()
mb_list = self._prepare_mb_list(input_)
mb_list = self.prepare_mb_list(input_)
total_loss_weight = torch.tensor(
sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32
@ -394,7 +188,6 @@ class FSDPEngine(TrainEngine):
for i, (pad_length, padded_mb_input, mb_input) in enumerate(
zip(mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs)
):
self.model.set_is_last_backward(i == len(mb_list.mbs) - 1)
outputs = self.model(**padded_mb_input)
logits = outputs.logits.squeeze(0)
@ -409,6 +202,7 @@ class FSDPEngine(TrainEngine):
loss *= loss_scale
loss.backward()
# NOTE: grad norm clip function is different
grad_norm = fsdp2_clip_grad_norm_(
self.model.parameters(), max_norm=self.optimizer_config.gradient_clipping
)
@ -427,72 +221,3 @@ class FSDPEngine(TrainEngine):
grad_norm=float(grad_norm) if grad_norm is not None else float("nan"),
lr=current_lr,
)
@torch.no_grad()
def eval_batch(
self,
input_: TensorDict,
loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
loss_weight_fn: Callable[[Dict], float],
) -> torch.Tensor | None:
"""Evaluate on a batch."""
input_ = input_.to(self.device)
mb_list = self._prepare_mb_list(input_)
total_loss_weight = torch.tensor(
sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32
)
assert total_loss_weight != 0
total_loss = 0.0
total_weight = 0.0
for pad_length, padded_mb_input, mb_input in zip(
mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
):
outputs = self.model(**padded_mb_input)
logits = outputs.logits.squeeze(0)
logits = logits[:-pad_length] if pad_length > 0 else logits
loss = loss_fn(logits, mb_input)
# Simple weight calculation (could be improved)
loss_scale = loss_weight_fn(mb_input) / total_loss_weight
total_loss += loss.item() * loss_scale
total_weight += loss_scale
return torch.tensor(total_loss / total_weight)
@torch.no_grad()
def forward(
self,
input_: TensorDict,
output_seqlens: List[int] | None = None,
post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
) -> Any | None:
"""Forward pass with optional post-processing."""
input_ = input_.to(self.device)
cu_seqlens = pack_tensor_dict(input_)["cu_seqlens"]
mb_list = self._prepare_mb_list(input_)
if output_seqlens is None:
output_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist()
results = []
for pad_length, padded_mb_input, mb_input in zip(
mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
):
outputs = self.model(**padded_mb_input)
logits = outputs.logits.squeeze(0)
logits = logits[:-pad_length] if pad_length > 0 else logits
if post_hook:
result = post_hook(logits, mb_input)
results.append(result)
else:
results.append(logits)
res = aggregate_fn(results)
output_seqlens = [output_seqlens[i] for i in mb_list.forward_indices]
unpacked = unpack_sequence(res, lens=output_seqlens, dim=0)
reordered = reorder_list(unpacked, mb_list.backward_indices)
return pad_and_stack_tensors_along_first_dim(reordered)
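A hedged sketch of saving and loading through the refactored save()/load() path above (field values are illustrative; the SaveLoadMeta fields follow the dataclass shown earlier in this diff):

from transformers import AutoTokenizer
from arealite.api.io_struct import SaveLoadMeta

tokenizer = AutoTokenizer.from_pretrained("/storage/models/Qwen3-1.7B")  # illustrative path
meta = SaveLoadMeta(
    path="/tmp/ckpt/step_100",      # illustrative output directory
    weight_format="hf",             # "hf" or "dcp" per the branches above
    with_optim=True,                # also writes per-rank optimizer shards
    tokenizer=tokenizer,
    base_model_path="/storage/models/Qwen3-1.7B",
)
# engine.save(meta)   # and later: engine.load(meta)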

View File

@ -1,467 +0,0 @@
import gc
import os
import time
from typing import Any, Callable, Dict, List, Optional
import torch
import torch.distributed as dist
import transformers
from safetensors.torch import save_file
from tensordict import TensorDict
from transformers import (
AutoConfig,
AutoModelForCausalLM,
get_constant_schedule_with_warmup,
get_linear_schedule_with_warmup,
)
from arealite.api.cli_args import TrainEngineConfig
from arealite.api.engine_api import (
FinetuneSpec,
SaveLoadMeta,
TrainEngine,
WeightUpdateMeta,
)
from arealite.utils.data import (
MicroBatchList,
amend_position_ids,
pack_tensor_dict,
pad_and_stack_tensors_along_first_dim,
pad_mb_list,
reorder_list,
split_packed_tensor_dict_into_mb_list,
unpack_sequence,
unsqueeze_mb_list,
)
from arealite.utils.fsdp import get_cosine_schedule_with_warmup
from arealite.utils.save_load import (
get_state_dict_from_repo_id_or_path,
is_existing_local_path,
)
from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import logging, name_resolve, names
logger = logging.getLogger("HFEngine")
class HFEngine(TrainEngine):
def __init__(self, config: TrainEngineConfig):
self.config = config
self.optimizer_config = config.optimizer
self.model = None
self.optimizer = None
self.tokenizer = None
# huggingface model config
self.model_config = None
# initialization
self.initialized = False
self.weight_update_group_initialized = False
def train(self, mode: bool = True):
assert self.model is not None
self.model.train(mode=mode)
return self
def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None):
"""Initialize distributed communication and model."""
assert addr is None, "HFEngine does not support remote initialization."
world_size = int(os.environ.get("WORLD_SIZE", 0))
if not dist.is_initialized() and world_size > 1:
try:
import deepspeed
except ImportError:
print(
"Warning: deepspeed is not installed. Some functionality may be disabled."
)
deepspeed.init_distributed(dist_backend="nccl", world_size=world_size)
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
self.device = torch.device(f"cuda:{local_rank}")
dtype = getattr(torch, self.config.dtype)
self.model_config = AutoConfig.from_pretrained(
pretrained_model_name_or_path=self.config.path,
trust_remote_code=True,
)
self.tokenizer = load_hf_tokenizer(self.config.path)
self.model = AutoModelForCausalLM.from_config(
self.model_config,
torch_dtype=dtype,
attn_implementation=self.config.attn_impl,
).to(f"cuda:{local_rank}")
if not self.config.init_from_scratch:
# Load model from a initial checkpoint path,
# which should only be a huggingface checkpoint.
load_meta = SaveLoadMeta(
path=self.config.path,
weight_format="hf",
with_optim=False,
tokenizer=None,
base_model_path=self.config.path,
naive_distributed=False,
)
self.load(load_meta)
if world_size > 1:
if self._check_autotp():
self.model = deepspeed.tp_model_init(
self.model, tp_size=self.config.hf.autotp_size, dtype=dtype
)
else:
raise RuntimeError("DeepSpeed AutoTP configuration error in HFEngine. ")
# Set up optimizer
if self.optimizer_config is not None:
assert (
self.optimizer_config.type == "adam"
), "Only AdamW optimizer is supported in this engine."
lr = self.optimizer_config.lr
weight_decay = self.optimizer_config.weight_decay
beta1 = self.optimizer_config.beta1
beta2 = self.optimizer_config.beta2
eps = self.optimizer_config.eps
self.optimizer = torch.optim.AdamW(
self.model.parameters(),
lr=lr,
weight_decay=weight_decay,
betas=(beta1, beta2),
eps=eps,
)
total_train_steps = ft_spec.total_train_steps
num_warmup_steps = int(
self.optimizer_config.warmup_steps_proportion * total_train_steps
)
if self.optimizer_config.lr_scheduler_type == "cosine":
self.lr_scheduler = get_cosine_schedule_with_warmup(
self.optimizer,
num_warmup_steps,
total_train_steps,
min_lr_ratio=self.optimizer_config.min_lr_ratio,
)
elif self.optimizer_config.lr_scheduler_type == "linear":
self.lr_scheduler = get_linear_schedule_with_warmup(
self.optimizer,
num_warmup_steps,
total_train_steps,
)
elif self.optimizer_config.lr_scheduler_type == "constant":
self.lr_scheduler = get_constant_schedule_with_warmup(
self.optimizer,
num_warmup_steps,
)
else:
raise ValueError(
f"Unknown lr scheduler type {self.optimizer_config.lr_scheduler_type}"
)
self.initialized = True
def _check_autotp(self):
tp_size = self.config.hf.autotp_size
config = self.model_config
num_attention_heads = config.num_attention_heads
num_key_value_heads = config.num_key_value_heads
hidden_size = config.hidden_size
intermediate_size = config.intermediate_size
return (
num_attention_heads % tp_size == 0
and num_key_value_heads % tp_size == 0
and hidden_size % tp_size == 0
and intermediate_size % tp_size == 0
)
def destroy(self):
"""Destroy the engine and release GPU memory."""
self.model = None
self.optimizer = None
gc.collect()
torch.cuda.empty_cache()
gc.collect()
self.initialized = False
def save(self, meta: SaveLoadMeta):
if meta.weight_format == "hf":
self._save_model_to_hf(meta.path, meta.tokenizer, meta.naive_distributed)
elif meta.weight_format == "dcp":
# TODO: implement DCP save/load for HF
raise NotImplementedError("DCP format saving is not implemented yet. ")
else:
raise ValueError(f"Unknown weight format {meta.weight_format}. ")
if meta.with_optim:
self._save_optimizer_state(meta.path)
def load(self, meta: SaveLoadMeta):
if meta.weight_format == "hf":
self._load_model_from_hf(meta.path, meta.naive_distributed)
elif meta.weight_format == "dcp":
# TODO: implement DCP save/load for HF
raise NotImplementedError("DCP format loading is not implemented yet. ")
else:
raise ValueError(f"Unknown weight format {meta.weight_format}. ")
if meta.with_optim:
self._load_optimizer_state(meta.path)
def _save_optimizer_state(self, path: str):
assert self.optimizer is not None
os.makedirs(path, exist_ok=True)
torch.save(self.optimizer.state_dict(), os.path.join(path, "optim.pt"))
def _load_optimizer_state(self, path: str):
assert self.optimizer is not None
path = os.path.join(path, "optim.pt")
optimizer_state_dict = torch.load(path, weights_only=False)
self.optimizer.load_state_dict(optimizer_state_dict)
def _save_model_to_hf(
self,
path: str,
tokenizer: Optional[transformers.PreTrainedTokenizerFast],
naive_distributed: bool,
):
"""Save model in HuggingFace format."""
if self.model is None:
raise RuntimeError("Model not initialized")
rank = dist.get_rank()
world_size = dist.get_world_size()
if rank == 0:
os.makedirs(path, exist_ok=True)
self.model_config.save_pretrained(path)
if tokenizer is not None:
tokenizer.save_pretrained(path)
if world_size > 1:
dist.barrier()
state_dict = self.model.state_dict()
if hasattr(self.model, "module"):
state_dict = {
k.replace("module.", "", 1) if k.startswith("module.") else k: v.cpu()
for k, v in state_dict.items()
}
else:
state_dict = {k: v.cpu() for k, v in state_dict.items()}
if world_size > 1 and naive_distributed:
# Only support store parameters from model partitions respectively
gathered_state_dicts = None
if rank == 0:
gathered_state_dicts = [None for _ in range(world_size)]
dist.gather_object(
obj=state_dict, object_gather_list=gathered_state_dicts, dst=0
)
if rank == 0:
for i, state_dict in enumerate(gathered_state_dicts):
save_file(state_dict, f"{path}/rank_{i:02d}_model.safetensors")
else:
self.model.save_pretrained(path, state_dict=state_dict)
if world_size > 1:
dist.barrier()
def _load_model_from_hf(self, path: str, naive_distributed: bool):
"""Load model from HuggingFace format."""
rank = dist.get_rank()
# Only support load full model parameters from huggingface
# and load model partition locally
if rank == 0 or is_existing_local_path(path):
if naive_distributed:
path = f"{path}/rank_{rank:02d}_model.safetensors"
full_state = get_state_dict_from_repo_id_or_path(path)
if hasattr(self.model, "module") and not hasattr(full_state):
full_state = {
f"module.{k}" if not k.startswith("module.") else k: v
for k, v in full_state.items()
}
self.model.load_state_dict(
full_state, strict=not self.model_config.tie_word_embeddings
)
if self.model_config.tie_word_embeddings:
self.model.tie_weights()
def upload_weights(self, meta: WeightUpdateMeta):
if meta.type == "nccl":
if not self.weight_update_group_initialized:
self._init_distributed_weight_update(meta)
self._update_weights_from_distributed()
elif meta.type == "disk":
self._save_model_to_hf(meta.path, self.tokenizer, meta.naive_distributed)
update_name = names.update_weights_from_disk(
self.config.experiment_name,
self.config.trial_name,
meta.model_version,
)
name_resolve.add(update_name, str(time.time_ns()), keepalive_ttl=120)
else:
raise ValueError(f"Unknown weight update type {meta.type}")
def _init_distributed_weight_update(self, meta: WeightUpdateMeta):
raise NotImplementedError(
"Distributed weight update is not implemented for HFEngine yet. "
)
def _update_weights_from_distributed(self):
raise NotImplementedError(
"Distributed weight update is not implemented for HFEngine yet. "
)
def step_lr_scheduler(self):
assert self.lr_scheduler is not None
return self.lr_scheduler.step()
def _prepare_mb_list(self, input_: TensorDict) -> MicroBatchList:
assert "attention_mask" in input_ and "input_ids" in input_
if isinstance(input_, dict):
input_ = TensorDict(input_, batch_size=[input_["input_ids"].shape[0]])
input_ = amend_position_ids(input_)
packed_input = pack_tensor_dict(input_)
mb_list = split_packed_tensor_dict_into_mb_list(
packed_input,
self.config.mb_spec,
)
mb_list = pad_mb_list(mb_list, pad_value=0.0)
# NOTE: We unsqueeze here because huggingface transformer models requires
# packed input to be of shape [1, total_seqlen].
mb_list = unsqueeze_mb_list(mb_list)
return mb_list
def train_batch(
self,
input_: TensorDict,
loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
loss_weight_fn: Callable[[Dict], float],
) -> Dict[str, float]:
"""Train on a batch using gradient accumulation."""
input_ = input_.to(self.device)
assert self.optimizer is not None
assert self.optimizer_config is not None
assert self.lr_scheduler is not None
self.optimizer.zero_grad()
mb_list = self._prepare_mb_list(input_)
total_loss_weight = torch.tensor(
sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32
)
assert total_loss_weight != 0
# Process microbatches with gradient accumulation
for pad_length, padded_mb_input, mb_input in zip(
mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
):
outputs = self.model(**padded_mb_input)
logits = outputs.logits.squeeze(0)
logits = logits[:-pad_length] if pad_length > 0 else logits
loss = loss_fn(logits, mb_input)
loss_scale = loss_weight_fn(mb_input) / total_loss_weight
loss *= loss_scale
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
self.optimizer_config.gradient_clipping,
norm_type=2.0,
error_if_nonfinite=False,
foreach=None,
)
if not torch.isfinite(grad_norm):
self.optimizer.zero_grad()
update_successful = False
else:
self.optimizer.step()
update_successful = True
current_lr = self.lr_scheduler.get_last_lr()[0]
# Optimizer step
self.optimizer.step()
return dict(
update_successful=float(update_successful),
grad_norm=float(grad_norm) if grad_norm is not None else float("nan"),
lr=current_lr,
)
@torch.no_grad()
def eval_batch(
self,
input_: TensorDict,
loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
loss_weight_fn: Callable[[Dict], float],
) -> torch.Tensor | None:
"""Evaluate on a batch."""
mb_list = self._prepare_mb_list(input_)
total_loss_weight = torch.tensor(
sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32
)
assert total_loss_weight != 0
total_loss = 0.0
total_weight = 0.0
for pad_length, padded_mb_input, mb_input in zip(
mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
):
outputs = self.model(**padded_mb_input)
logits = outputs.logits.squeeze(0)
logits = logits[:-pad_length] if pad_length > 0 else logits
loss = loss_fn(logits, mb_input)
# Simple weight calculation (could be improved)
loss_scale = loss_weight_fn(mb_input) / total_loss_weight
total_loss += loss.item() * loss_scale
total_weight += loss_scale
return torch.tensor(total_loss / total_weight)
@torch.no_grad()
def forward(
self,
input_: TensorDict,
output_seqlens: List[int] | None = None,
post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
) -> Any | None:
"""Forward pass with optional post-processing."""
cu_seqlens = pack_tensor_dict(input_)["cu_seqlens"]
mb_list = self._prepare_mb_list(input_)
if output_seqlens is None:
output_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist()
results = []
for pad_length, padded_mb_input, mb_input in zip(
mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
):
outputs = self.model(**padded_mb_input)
logits = outputs.logits.squeeze(0)
logits = logits[:-pad_length] if pad_length > 0 else logits
if post_hook:
result = post_hook(logits, mb_input)
results.append(result)
else:
results.append(logits)
res = aggregate_fn(results)
output_seqlens = [output_seqlens[i] for i in mb_list.forward_indices]
unpacked = unpack_sequence(res, lens=output_seqlens, dim=0)
reordered = reorder_list(unpacked, mb_list.backward_indices)
return pad_and_stack_tensors_along_first_dim(reordered)

View File

@ -114,7 +114,9 @@ class PPOActor:
values = torch.zeros_like(rewards)
else:
values = data["values"]
advantages_reversed = []
advantages_reversed = [
torch.zeros(bs, dtype=torch.float32, device=values.device)
]
lastgaelam = 0
for t in reversed(range(max_seqlen - 1)):
nextvalues = values[:, t + 1]
@ -123,9 +125,6 @@ class PPOActor:
delta = rewards[:, t] + self.discount * nextvalues - values[:, t]
lastgaelam = delta + self.discount * self.gae_lambda * lastgaelam
advantages_reversed.append(lastgaelam)
advantages_reversed.append(
torch.zeros(bs, dtype=torch.float32, device=values.device)
)
advantages = torch.stack(advantages_reversed[::-1], dim=1)
# Optionally perform advantage normalization.
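This is one of the GRPO fixes the commit title refers to: the zero bootstrap used to be appended after the backward loop, so the [::-1] reversal placed it at t=0 instead of at the final step. With the change, the last step's advantage is zero and the recursion lastgaelam = delta_t + discount * gae_lambda * lastgaelam walks backwards from there. A minimal standalone sketch of the fixed ordering (toy shapes, loss masking omitted):

import torch

def gae_with_zero_bootstrap(rewards: torch.Tensor, values: torch.Tensor,
                            discount: float, gae_lambda: float) -> torch.Tensor:
    bs, max_seqlen = rewards.shape
    advantages_reversed = [torch.zeros(bs, dtype=torch.float32)]  # advantage at the last step
    lastgaelam = torch.zeros(bs, dtype=torch.float32)
    for t in reversed(range(max_seqlen - 1)):
        delta = rewards[:, t] + discount * values[:, t + 1] - values[:, t]
        lastgaelam = delta + discount * gae_lambda * lastgaelam
        advantages_reversed.append(lastgaelam)
    return torch.stack(advantages_reversed[::-1], dim=1)  # shape [bs, max_seqlen]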

View File

@ -1,5 +1,3 @@
from typing import Dict
import torch
import torch.utils.data
from tensordict import TensorDict
@ -44,9 +42,7 @@ class FSDPLMEngine(FSDPEngine):
return self.lm_engine.evaluate_lm(data)
def compute_packed_sft_loss(
logits: torch.Tensor, input_: Dict[str, torch.Tensor]
) -> torch.Tensor:
def compute_packed_sft_loss(logits: torch.Tensor, input_: TensorDict) -> torch.Tensor:
packed_input_ids: torch.Tensor = input_["input_ids"]
cu_seqlens: torch.Tensor = input_["cu_seqlens"]
loss_mask = input_["loss_mask"].bool()

View File

@ -7,10 +7,12 @@ import traceback
from concurrent.futures import ProcessPoolExecutor
from datetime import datetime
from queue import Empty, Full, Queue
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List
from typing import TYPE_CHECKING, Any, Callable, Dict, List
import aiohttp
import requests
import torch.distributed as dist
import uvloop
from tensordict import TensorDict
from torchdata.stateful_dataloader import StatefulDataLoader
@ -23,8 +25,8 @@ from arealite.api.io_struct import (
RolloutStat,
WeightUpdateMeta,
)
from arealite.utils.http import arequest_with_retry
from arealite.utils.padding import concat_padded_tensors
from arealite.utils.data import concat_padded_tensors
from arealite.utils.http import arequest_with_retry, get_default_connector
from realhf.base import logging, name_resolve, names, pkg_version
if TYPE_CHECKING:
@ -37,7 +39,7 @@ if pkg_version.is_available("sglang"):
else:
SGLANG_TOKEN_OUTPUT_IDENTIFIER = "token_ids"
ROLLOUT_POLL_WAIT_TIME = 0.1
ROLLOUT_POLL_WAIT_TIME = 0.05
RID_CACHE_SIZE = 128
@ -98,6 +100,7 @@ class RemoteSGLangEngine(InferenceEngine):
def initialize(self, addr: str | None, ft_spec: FinetuneSpec = None):
self.rollout_tasks: Dict[str, asyncio.Task] = {}
self.executor = ProcessPoolExecutor(max_workers=1)
self.rollout_thread = threading.Thread(target=self._rollout_thread)
self.rollout_thread.start()
@ -118,32 +121,39 @@ class RemoteSGLangEngine(InferenceEngine):
def _rollout_thread(self):
"""Thread that runs the rollout loop."""
try:
asyncio.run(self._rollout_thread_async())
uvloop.run(self._rollout_thread_async())
except Exception as e:
traceback.print_exc()
async def _rollout_thread_async(self):
pending_data = []
rollout_tasks = self.rollout_tasks
rid = 0
# NOTE: session is not thread-safe, but we only submit requests in the sub-thread.
self.session = aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(
total=self.config.request_timeout,
sock_connect=self.config.request_timeout,
connect=self.config.request_timeout,
),
read_bufsize=1024 * 1024 * 10,
connector=get_default_connector(),
)
try:
while not self.exiting.is_set():
# Load next data from controller
while True:
try:
data, workflow = self.input_queue.get_nowait()
logger.debug(f"Get data from puller: {data}")
pending_data.append(data)
except Empty:
logger.debug(f"No data from puller stream.")
break
# Check capacity
capacity = self.get_capacity()
# Create new rollout task
while capacity > 0 and pending_data and not self.paused.is_set():
while (
capacity > 0
and not self.paused.is_set()
and self.input_queue.qsize() > 0
):
data, workflow = self.input_queue.get_nowait()
logger.debug(f"Get data from puller: {data}")
task = asyncio.create_task(
workflow.arun_episode(self, pending_data.pop(0)), name=str(rid)
workflow.arun_episode(self, data), name=str(rid)
)
with self.lock:
rollout_tasks[str(rid)] = task
@ -158,7 +168,6 @@ class RemoteSGLangEngine(InferenceEngine):
)
capacity -= 1
rid += 1
# Wait for rollout completion
with self.lock:
tasks = list(rollout_tasks.values())
@ -169,11 +178,6 @@ class RemoteSGLangEngine(InferenceEngine):
timeout=ROLLOUT_POLL_WAIT_TIME,
return_when=asyncio.FIRST_COMPLETED,
)
if not done:
await asyncio.sleep(1)
else:
await asyncio.sleep(1)
# Collect done results
for task in done:
traj = await task
@ -199,6 +203,7 @@ class RemoteSGLangEngine(InferenceEngine):
f"running: {self.rollout_stat.running}, "
f"accepted: {self.rollout_stat.accepted}."
)
await asyncio.sleep(1)
except Exception:
traceback.print_exc()
finally:
@ -213,6 +218,7 @@ class RemoteSGLangEngine(InferenceEngine):
pass
def choose_server(self) -> str:
with self.lock:
if self.config.schedule_policy == "round_robin":
server = self.addresses[self.server_idx]
self.server_idx = (self.server_idx + 1) % len(self.addresses)
@ -253,7 +259,6 @@ class RemoteSGLangEngine(InferenceEngine):
accumulated_versions = []
# Deal with rollout interruption
completions = ""
stop_reason = "length"
if req.rid in self.rid_to_address:
@ -273,11 +278,12 @@ class RemoteSGLangEngine(InferenceEngine):
):
# loop until the generation is complete
result = await arequest_with_retry(
addr=self.choose_server(),
session=self.session,
addr=server_addr,
endpoint="/generate",
payload=payload,
method="POST",
max_retries=3,
max_retries=self.config.request_retries,
timeout=self.config.request_timeout,
)
@ -297,10 +303,7 @@ class RemoteSGLangEngine(InferenceEngine):
stop_reason = finish_reason["type"]
payload["input_ids"] += result[SGLANG_TOKEN_OUTPUT_IDENTIFIER]
sample_params["max_new_tokens"] = min(
sample_params["max_new_tokens"],
gconfig.max_new_tokens - len(output_tokens),
)
sample_params["max_new_tokens"] -= len(output_tokens)
latency = time.perf_counter() - start_time
@ -408,18 +411,24 @@ class RemoteSGLangEngine(InferenceEngine):
def prepare_batch(
self,
data_generator: Iterator,
dataloader: StatefulDataLoader,
workflow: "RolloutWorkflow",
):
if not hasattr(self, "data_generator"):
self.data_generator = iter(dataloader)
assert dataloader.batch_size is not None
while True:
if self.get_capacity() + dataloader.batch_size > 0:
# Submit at least two batches to allow maximum overlap
if (
self.get_capacity() + dataloader.batch_size > 0
and self.input_queue.qsize() + dataloader.batch_size
< self.input_queue.maxsize
):
try:
data = next(data_generator)
data = next(self.data_generator)
except StopIteration:
data_generator = iter(dataloader)
data = next(data_generator)
self.data_generator = iter(dataloader)
data = next(self.data_generator)
for item in data:
self.submit(item, workflow=workflow)
try:
@ -435,10 +444,12 @@ class RemoteSGLangEngine(InferenceEngine):
async def aupdate_weights_from_disk(
addr, path: str, request_retries: int, request_timeout: float
session, addr, path: str, request_retries: int, request_timeout: float
):
tik = time.time()
res = await arequest_with_retry(
addr=addr,
session=session,
endpoint="/update_weights_from_disk",
payload=dict(model_path=str(path), allow_interrupt=True),
method="POST",
@ -472,9 +483,19 @@ def update_weights_from_disk(
logger.info(
f"Begin update weights from {path}, responded in {(load_timestamp - save_timestamp):.2f}s"
)
session = aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(
total=request_timeout,
sock_connect=request_timeout,
connect=request_timeout,
),
read_bufsize=1024 * 1024 * 10,
connector=get_default_connector(),
)
jobs = [
aupdate_weights_from_disk(
addr,
session=session,
addr=addr,
path=path,
request_retries=request_retries,
request_timeout=request_timeout,
@ -482,8 +503,9 @@ def update_weights_from_disk(
for addr in addresses
]
await asyncio.gather(*jobs)
await session.close()
logger.info(
f"Loading weights done in {(datetime.now().timestamp() - load_timestamp):.2f}s"
)
return asyncio.run(_fn())
return uvloop.run(_fn())
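
Returning to the prepare_batch change above: the engine now owns its own iterator over the StatefulDataLoader and throttles submission by both remaining rollout capacity and input-queue headroom, so roughly two batches stay in flight to overlap generation with training. A minimal standalone sketch of that throttle condition follows; the helper name and the plain queue.Queue are illustrative assumptions, not the engine's actual attributes.

import queue

def should_submit_next_batch(capacity: int, batch_size: int, input_queue: queue.Queue) -> bool:
    # capacity can go negative once enough rollouts are already running;
    # the queue check keeps at most roughly two batches buffered ahead.
    return (
        capacity + batch_size > 0
        and input_queue.qsize() + batch_size < input_queue.maxsize
    )

# e.g. with an input queue sized for two batches of 8 prompts:
q: queue.Queue = queue.Queue(maxsize=16)
assert should_submit_next_batch(capacity=4, batch_size=8, input_queue=q)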


@ -5,6 +5,7 @@ import time
import uuid
import pytest
import requests
import torch
from tensordict import TensorDict
@ -28,6 +29,17 @@ HOST = network.gethostip()
RUN_SERVER_TIMEOUT = 180
def check_server_health(base_url):
try:
response = requests.get(
f"{base_url}/metrics",
timeout=30,
)
return response.status_code == 200
except requests.exceptions.RequestException as e:
return False
@pytest.fixture(scope="module")
def sglang_server():
from realhf.base import seeding
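
The new check_server_health helper above simply probes the /metrics endpoint and treats any non-200 response or request error as unhealthy. A hypothetical polling loop built on it could look like the sketch below; the port argument and the one-second interval are assumptions, since the fixture body is not shown in this hunk.

import time

def wait_for_server(port: int) -> None:
    # Poll /metrics until the SGLang server answers, giving up after
    # RUN_SERVER_TIMEOUT seconds.
    base_url = f"http://{HOST}:{port}"
    deadline = time.time() + RUN_SERVER_TIMEOUT
    while time.time() < deadline:
        if check_server_health(base_url):
            return
        time.sleep(1)
    raise RuntimeError(f"SGLang server at {base_url} never became healthy")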


@ -67,7 +67,6 @@ async def test_local_sglang_generate():
gconfig=GenerationHyperparameters(max_new_tokens=16),
)
resp = await engine.agenerate(req)
print(resp.completions)
assert isinstance(resp, LLMResponse)
assert resp.input_tokens == req.input_ids
@ -76,9 +75,6 @@ async def test_local_sglang_generate():
== len(resp.output_tokens)
== len(resp.output_versions)
)
assert isinstance(resp.completions, str)
time.sleep(5)
engine.destroy()


@ -13,7 +13,6 @@ from transformers import AutoTokenizer
from arealite.api.cli_args import MicroBatchSpec, OptimizerConfig, TrainEngineConfig
from arealite.api.io_struct import FinetuneSpec, SaveLoadMeta
from arealite.engine.fsdp_engine import FSDPEngine
VOCAB_SIZE = 100
MODEL_PATH = "/storage/testing/models/Qwen__Qwen3-1.7B/"
@ -53,10 +52,10 @@ def mock_input(
def get_engine(engine_type: str, model_path: str):
from arealite.engine.autotp_engine import DeepSpeedAutoTPEngine
from arealite.engine.fsdp_engine import FSDPEngine
from arealite.engine.hf_engine import HFEngine
engine_cls = {"hf": HFEngine, "fsdp": FSDPEngine}[engine_type]
engine_cls = {"auto_tp": DeepSpeedAutoTPEngine, "fsdp": FSDPEngine}[engine_type]
engine_config = TrainEngineConfig(
experiment_name=f"test-{engine_type}-engine",
@ -75,7 +74,7 @@ def mock_loss_fn(logits: torch.Tensor, input_data: Dict) -> torch.Tensor:
return torch.mean(logits)
@pytest.fixture(scope="module", params=["fsdp", "hf"])
@pytest.fixture(scope="module", params=["fsdp", "auto_tp"])
def engine(request):
os.environ.update(
{
@ -136,6 +135,12 @@ def test_train_batch(engine, mock_input):
@torch.no_grad()
def test_hf_save_load_weights(tmp_path_factory, engine, mock_input):
from arealite.engine.autotp_engine import DeepSpeedAutoTPEngine
if isinstance(engine, DeepSpeedAutoTPEngine):
print("AutoTP engine does not support HF save/load for now.")
return
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
path = tmp_path_factory.mktemp("hf_engine_test")
save_load_meta = SaveLoadMeta(


@ -92,9 +92,7 @@ def unpad_input(
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
)
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
rearrange(hidden_states, "b s ... -> (b s) ...")[indices],
indices,
@ -116,16 +114,12 @@ def concat_padded_tensors(
if not tensor_dicts:
return TensorDict()
# Find max sequence length across all dictionaries
lens = []
for tensor_dict in tensor_dicts:
for key, tensor in tensor_dict.items():
if key != "attention_mask" and len(tensor.shape) == 2:
lens.append(tensor.shape[1])
break
max_length = max(lens)
attn_mask = torch.arange(max_length).unsqueeze(0) < torch.tensor(lens).unsqueeze(1)
batch_sizes = [tuple(d.batch_size) for d in tensor_dicts]
new_batch_size = [sum(x[0] for x in batch_sizes), *batch_sizes[0][1:]]
# Find max sequence length across all dictionaries
assert all("attention_mask" in td for td in tensor_dicts)
max_length = max([x["attention_mask"].shape[1] for x in tensor_dicts])
result = {}
# Process each key
for key in tensor_dicts[0].keys():
@ -154,9 +148,7 @@ def concat_padded_tensors(
tensors_to_concat.append(tensor)
result[key] = torch.cat(tensors_to_concat, dim=0)
if "attention_mask" not in result:
result["attention_mask"] = attn_mask
return TensorDict(result, batch_size=[len(lens)])
return TensorDict(result, batch_size=new_batch_size)
def to_device(data: Dict[str, torch.Tensor | Any], device) -> Dict[str, torch.Tensor]:
@ -231,7 +223,7 @@ def pack_tensor_dict(data: TensorDict):
cu_seqlens = torch.cumsum(lens, dim=0)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
total_length = cu_seqlens[-1].item()
total_length = int(cu_seqlens[-1].item())
# Pack tensors
packed_data = {}
for key, value in data.items():
@ -246,7 +238,7 @@ def pack_tensor_dict(data: TensorDict):
and value.shape[1] == seq_len
):
packed_tensor = torch.empty(
total_length, *value.shape[2:], dtype=value.dtype, device=value.device
(total_length, *value.shape[2:]), dtype=value.dtype, device=value.device
)
# Fill the packed tensor with values from the original tensor
for i in range(bs):
@ -363,7 +355,7 @@ def split_padded_tensor_dict_into_mb_list(
mbs=results,
mb_spec=mb_spec,
forward_indices=forward_indices,
backward_indices=backward_indices,
backward_indices=backward_indices.tolist(),
group_lens=group_lens,
)
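
With this refactor, concat_padded_tensors requires every input TensorDict to carry an attention_mask, derives the padded length from the longest mask, and sums the input batch sizes instead of recomputing them from per-key lengths. A small illustrative call, assuming non-mask tensors are padded with zeros up to that length:

import torch
from tensordict import TensorDict

from arealite.utils.data import concat_padded_tensors

a = TensorDict(
    {"input_ids": torch.ones(2, 3, dtype=torch.long),
     "attention_mask": torch.ones(2, 3, dtype=torch.bool)},
    batch_size=[2],
)
b = TensorDict(
    {"input_ids": torch.ones(1, 5, dtype=torch.long),
     "attention_mask": torch.ones(1, 5, dtype=torch.bool)},
    batch_size=[1],
)
merged = concat_padded_tensors([a, b])
# Rows from `a` are padded out to length 5, and the batch sizes are summed.
assert tuple(merged.batch_size) == (3,)
assert merged["input_ids"].shape == (3, 5)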


@ -121,19 +121,24 @@ def masked_normalization(
def ppo_actor_loss_fn(
logprobs: torch.Tensor,
proximal_logprobs: torch.Tensor,
old_logprobs: torch.Tensor,
advantages: torch.Tensor,
eps_clip: float,
loss_mask: torch.Tensor,
c_clip: Optional[float] = None,
proximal_logprobs: Optional[torch.Tensor] = None,
behav_imp_weight_cap: Optional[float] = None,
) -> Tuple[torch.Tensor, Dict]:
denorm_logprobs = (
proximal_logprobs if proximal_logprobs is not None else old_logprobs
)
"""
When the decoupled loss is disabled:
1. if logprobs are recomputed, both old_logprobs and proximal_logprobs are the recomputed logprobs;
2. without recomputation, both old_logprobs and proximal_logprobs are produced by the inference backend.
When the decoupled loss is enabled, proximal_logprobs is the recomputed logprobs
and old_logprobs is produced by the inference engine.
"""
loss_mask_count = loss_mask.count_nonzero() or 1
ratio = torch.where(loss_mask, torch.exp(logprobs - denorm_logprobs), 0)
ratio = torch.where(loss_mask, torch.exp(logprobs - proximal_logprobs), 0)
clipped_ratio = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip)
pg_loss1 = -advantages * ratio
pg_loss2 = -advantages * clipped_ratio
@ -146,7 +151,6 @@ def ppo_actor_loss_fn(
pg_loss = torch.min(pg_loss, pg_loss3)
else:
dual_clip_mask = torch.zeros_like(clip_mask)
if proximal_logprobs is not None:
behav_kl = proximal_logprobs - old_logprobs
behav_imp_weight = behav_kl.exp()
behav_mask = (
@ -164,7 +168,7 @@ def ppo_actor_loss_fn(
stat = dict(
loss=logging_loss,
importance_weight=ratio.detach(),
approx_kl=(logprobs - denorm_logprobs).detach(),
approx_kl=(logprobs - proximal_logprobs).detach(),
clip_mask=clip_mask,
dual_clip_mask=dual_clip_mask,
)
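
After this change the PPO ratio is always taken against proximal_logprobs, and old_logprobs only enters through the behavioral importance weight of the decoupled loss. A toy, self-contained illustration of that split (not the project's implementation, which additionally applies the loss mask and dual clipping):

import torch

logprobs = torch.tensor([-1.0, -0.8, -1.2])           # current policy
proximal_logprobs = torch.tensor([-1.1, -0.9, -1.0])  # recomputed ("proximal") logprobs
old_logprobs = torch.tensor([-1.3, -0.7, -1.1])       # inference-engine logprobs
advantages = torch.tensor([0.5, -0.2, 1.0])
eps_clip, behav_imp_weight_cap = 0.2, 5.0

# Standard PPO clipping, but against the proximal policy.
ratio = torch.exp(logprobs - proximal_logprobs)
clipped_ratio = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip)
pg_loss = torch.max(-advantages * ratio, -advantages * clipped_ratio)

# Decoupled correction: reweight by how far the proximal policy has drifted
# from the behavior (inference-time) policy, dropping extreme weights.
behav_imp_weight = (proximal_logprobs - old_logprobs).exp()
keep = behav_imp_weight <= behav_imp_weight_cap
loss = (pg_loss * behav_imp_weight)[keep].mean()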


@ -3,45 +3,74 @@ from typing import Any, Dict, Optional
import aiohttp
from realhf.base import logging
DEFAULT_RETRIES = 1
DEFAULT_REQUEST_TIMEOUT = 3600
logger = logging.getLogger(__file__)
def get_default_connector():
return aiohttp.TCPConnector(limit=0, use_dns_cache=False, force_close=True)
async def arequest_with_retry(
addr: str,
endpoint: str,
payload: Optional[Dict[str, Any]] = None,
session: aiohttp.ClientSession | None = None,
method: str = "POST",
max_retries: Optional[int] = None,
timeout: Optional[float] = None,
retry_delay: float = 1.0,
) -> aiohttp.ClientResponse:
verbose=False,
) -> Dict:
timeout = timeout or DEFAULT_REQUEST_TIMEOUT
last_exception = None
max_retries = max_retries or DEFAULT_RETRIES
base_url = f"http://{addr}"
url = f"{base_url}{endpoint}"
for attempt in range(max_retries):
try:
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(
timeo = aiohttp.ClientTimeout(
total=timeout,
sock_connect=timeout,
connect=timeout,
)
) as session:
if session is None:
_session = aiohttp.ClientSession(
timeout=timeo,
read_bufsize=1024 * 1024 * 10,
connector=get_default_connector(),
)
else:
_session = session
for attempt in range(max_retries):
try:
if verbose:
logger.info("enter client session, start sending requests")
if method.upper() == "GET":
response = await session.get(url)
ctx = _session.get(url, timeout=timeo)
elif method.upper() == "POST":
response = await session.post(url, json=payload)
ctx = _session.post(url, json=payload, timeout=timeo)
elif method.upper() == "PUT":
response = await session.put(url, json=payload)
ctx = _session.put(url, json=payload, timeout=timeo)
elif method.upper() == "DELETE":
response = await session.delete(url)
ctx = _session.delete(url, timeout=timeo)
else:
raise ValueError(f"Unsupported HTTP method: {method}")
async with ctx as response:
if verbose:
logger.info("http requests return")
response.raise_for_status()
return await response.json()
res = await response.json()
if verbose:
logger.info("get http result")
if session is None:
await _session.close()
return res
except (
aiohttp.ClientError,
aiohttp.ClientResponseError,
@ -51,6 +80,10 @@ async def arequest_with_retry(
if attempt < max_retries - 1:
await asyncio.sleep(retry_delay)
continue
if session is None:
await _session.close()
raise RuntimeError(
f"Failed after {max_retries} retries each. " f"Last error: {last_exception}"
f"Failed after {max_retries} retries each. "
f"Payload: {payload}. Addr: {addr}. Endpoint: {endpoint}. "
f"Last error: {last_exception}"
)
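
arequest_with_retry now accepts an optional aiohttp.ClientSession: when one is supplied, connections are reused across calls and the caller is responsible for closing it; otherwise a temporary session is created and closed inside the helper. A hedged usage sketch follows; the import path of the helpers is an assumption, since this hunk does not show the module name.

import asyncio

import aiohttp

from arealite.utils.http import arequest_with_retry, get_default_connector  # assumed path

async def query_all(addresses: list[str]) -> list[dict]:
    # One shared session for the whole fan-out instead of a new one per request.
    session = aiohttp.ClientSession(
        timeout=aiohttp.ClientTimeout(total=3600),
        read_bufsize=1024 * 1024 * 10,
        connector=get_default_connector(),
    )
    try:
        return await asyncio.gather(
            *[
                arequest_with_retry(
                    addr=addr,
                    endpoint="/metrics",
                    method="GET",
                    session=session,
                    max_retries=3,
                    timeout=30,
                )
                for addr in addresses
            ]
        )
    finally:
        await session.close()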


@ -8,7 +8,7 @@ from transformers import PreTrainedTokenizerFast
from arealite.api.cli_args import GenerationHyperparameters
from arealite.api.io_struct import LLMRequest
from arealite.api.workflow_api import RolloutWorkflow
from arealite.utils.padding import concat_padded_tensors
from arealite.utils.data import concat_padded_tensors
class RLVRWorkflow(RolloutWorkflow):
@ -61,7 +61,7 @@ class RLVRWorkflow(RolloutWorkflow):
versions=torch.tensor(versions).unsqueeze(0),
attention_mask=torch.ones(len(seq), dtype=torch.bool).unsqueeze(0),
# reward
rewards=torch.tensor([reward]),
rewards=torch.tensor([float(reward)]),
)
results.append(TensorDict(res, batch_size=[1]))
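
The explicit float(reward) cast matters because torch.tensor infers its dtype from the Python value, so a bool or int reward would otherwise yield an integer tensor whose dtype can clash when trajectories are concatenated into a batch. For example:

import torch

print(torch.tensor([1]).dtype)         # torch.int64
print(torch.tensor([True]).dtype)      # torch.bool
print(torch.tensor([float(1)]).dtype)  # torch.float32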


@ -152,11 +152,7 @@ def main_grpo():
with stats_tracker.record_timing("rollout"):
if config.async_training:
batch = rollout.prepare_batch(
data_generator,
train_dataloader,
workflow=workflow,
)
batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
else:
try:
data = next(data_generator)
@ -210,8 +206,9 @@ def main_grpo():
actor.upload_weights(meta)
if dist.get_rank() == 0:
future.result()
rollout.set_version(global_step + 1)
dist.barrier()
torch.cuda.synchronize()
rollout.set_version(global_step + 1)
with stats_tracker.record_timing("save"):
saver.save(actor, epoch, step, global_step)
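
Moving rollout.set_version after the barrier and CUDA synchronization ensures rollout workers cannot observe the new version before every trainer rank has finished uploading weights and the inference servers have loaded them. An annotated restatement of that ordering (names follow the script above; the comments are interpretation, not code from the diff):

actor.upload_weights(meta)            # each trainer rank uploads its weights
if dist.get_rank() == 0:
    future.result()                   # rank 0 waits for the async weight-update request
dist.barrier()                        # all trainer ranks have reached this point
torch.cuda.synchronize()              # flush in-flight CUDA work before proceeding
rollout.set_version(global_step + 1)  # only now advertise the new weight version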


@ -1360,7 +1360,9 @@ class RayNameResolveRepository:
def make_repository(args: "NameResolveConfig"):
if args.type == "nfs":
return NfsNameRecordRepository(args.nfs_record_root)
repo = NfsNameRecordRepository(args.nfs_record_root)
os.makedirs(repo.record_root, exist_ok=True)
return repo
elif args.type == "etcd3":
host, port = args.etcd3_addr.split(":")
return Etcd3NameRecordRepository(host=host, port=int(port))