mirror of https://github.com/inclusionAI/AReaL
[Fix] [lite] Merge from the internal repo to fix GRPO bugs and refactor the train engine (#181)
* PullRequest: 353 [Lite] Add gradient checkpointing to FSDPEngine
  Merge branch mzy/add-gradient-ckpt of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/353
  Reviewed-by: 博惟 <bowei.fw@antgroup.com>
  * add gradient checkpointing
* PullRequest: 354 [lite] GRPO pre-commit: minor changes in FSDP engine
  Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/354
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* PullRequest: 355 [Lite] GRPO pre-commit 2: Refactor RemoteSGLangEngine thread and SGLang configuration
  Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/355?tab=commit
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
  * fix
* PullRequest: 357 [lite] GRPO pre-commit 3: Fix typos and experiment utilities
  Merge branch fw/lite-fix2 of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/357?tab=comment
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
  * fix destroy process group
* PullRequest: 358 [lite] Support GRPO training locally with the GSM8k dataset
  Merge branch fw/lite-fix3 of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/358
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
  * fix loss mask
  * fix
* PullRequest: 368 [lite] Refactor train engine after merging contributions from GitHub
  Merge branch fw/lite-train-engine of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/368
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* PullRequest: 371 [lite] [fix] fix misc bugs in GRPO implementation
  Merge branch fw/lite-fix0716 of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/371
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
---------
Co-authored-by: 晓雷 <meizhiyu.mzy@antgroup.com>
This commit is contained in:
parent 0283cfa124
commit 29e164a69d
@@ -18,7 +18,7 @@ from arealite.utils.fs import get_user_tmp
 class MicroBatchSpec:
     """Specification for splitting micro-batches during training."""
 
-    n_mbs: int = field(
+    n_mbs: Optional[int] = field(
         default=1,
         metadata={
             "help": "Number of micro-batches (or minimum number if max_tokens_per_mb is set). Used when max_tokens_per_mb is None or as minimum count",
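For context, a minimal sketch of how the relaxed field might be used; the max_tokens_per_mb keyword is assumed from the help text above, not shown in this hunk:

# Illustrative only: request at least 2 micro-batches, with an assumed
# max_tokens_per_mb field capping the tokens packed into each one.
from arealite.api.cli_args import MicroBatchSpec

mb_spec = MicroBatchSpec(n_mbs=2, max_tokens_per_mb=16384)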
@@ -161,7 +161,7 @@ class FSDPEngineConfig
 
 
 @dataclass
-class HFEngineConfig:
+class DeepSpeedAutoTPEngineConfig:
     autotp_size: Optional[int] = field(
         default=1,
         metadata={"help": "DeepSpeed AutoTP size"},
@@ -201,7 +201,88 @@ class TrainEngineConfig
     )
     backend: str = ""
     fsdp: FSDPEngineConfig = field(default_factory=FSDPEngineConfig)
-    hf: HFEngineConfig = field(default_factory=HFEngineConfig)
+    ds_auto_tp: DeepSpeedAutoTPEngineConfig = field(
+        default_factory=DeepSpeedAutoTPEngineConfig
+    )
+
+
+@dataclass
+class PPOActorConfig(TrainEngineConfig):
+    # Core PPO/GRPO Parameters
+    group_size: int = field(
+        default=1, metadata={"help": "Number of sequences in each group"}
+    )
+    group_adv_norm: bool = field(
+        default=False,
+        metadata={
+            "help": "Normalize advantages within each prompt group rather than globally"
+        },
+    )
+    ppo_n_minibatches: int = field(
+        default=4, metadata={"help": "Number of minibatches for each PPO update"}
+    )
+    eps_clip: float = field(
+        default=0.2, metadata={"help": "Clipping factor for policy ratio"}
+    )
+    c_clip: Optional[float] = field(
+        default=None,
+        metadata={
+            "help": "Dual clipping factor for policy ratio, must > 1.0. None disables dual clipping."
+        },
+    )
+    temperature: float = field(
+        default=1.0, metadata={"help": "Temperature during generation."}
+    )
+    # Reward
+    group_reward_norm: bool = field(
+        default=False,
+        metadata={
+            "help": "Normalize final reward of each sequence (GRPO-style) to reduce length bias"
+        },
+    )
+    reward_scaling: float = field(
+        default=1.0, metadata={"help": "Reward scaling factor"}
+    )
+    reward_bias: float = field(default=0.0, metadata={"help": "Reward bias"})
+    reward_clip: float = field(
+        default=20.0, metadata={"help": "Maximum absolute value for reward clipping"}
+    )
+    mask_no_eos_with_zero: bool = field(
+        default=False,
+        metadata={
+            "help": "Mask truncated generations (no EOS token) and exclude from training"
+        },
+    )
+
+    # Advantage Estimation
+    discount: float = field(
+        default=1.0, metadata={"help": "Discount factor for future rewards"}
+    )
+    gae_lambda: float = field(
+        default=1.0, metadata={"help": "Lambda parameter for GAE"}
+    )
+    adv_norm: bool = field(
+        default=True, metadata={"help": "Enable advantage normalization"}
+    )
+
+    # KL Control
+    kl_ctl: float = field(default=0.1, metadata={"help": "KL divergence coefficient"})
+
+    # Asynchronous RL
+    recompute_logprob: bool = field(
+        default=False,
+        metadata={"help": "Recompute logp and replace the logp returned by inference."},
+    )
+    use_decoupled_loss: bool = field(
+        default=False,
+        metadata={"help": "Use the decoupled loss. recompute_logprob must be True."},
+    )
+    behav_imp_weight_cap: Optional[float] = field(
+        default=None,
+        metadata={
+            "help": "We filter out the tokens where behav_imp_weight exceeds behav_imp_weight_cap when computing the loss, must be > 1.0, use_decoupled_loss must be true"
+        },
+    )
+
+
 @dataclass
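A minimal sketch of constructing the new GRPO/PPO actor config; it assumes the inherited TrainEngineConfig fields have usable defaults, and the values shown are illustrative rather than recommended settings:

from arealite.api.cli_args import PPOActorConfig

actor_config = PPOActorConfig(
    group_size=8,             # sequences sampled per prompt group
    group_reward_norm=True,   # GRPO-style per-group reward normalization
    adv_norm=True,
    eps_clip=0.2,
    kl_ctl=0.1,
)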
@@ -25,10 +25,10 @@ class Scheduling:
     cpu: int
     gpu: int
     mem: int
-    nodelist: str = None
-    exclude: str = None
-    partition: str = None
-    container_image: str = None
+    nodelist: Optional[str] = None
+    exclude: Optional[str] = None
+    partition: Optional[str] = None
+    container_image: Optional[str] = None
     env_vars: Dict[str, str] = field(default_factory=dict)
     # time utils from "https://slurm.schedmd.com/sbatch.html"
     time_limit: Optional[str] = None  # see "--time" option for format
@@ -105,7 +105,7 @@ class TrainEngine(abc.ABC):
     def forward(
         self,
         input_: TensorDict,
-        output_seqlens: List[List[int]] | None = None,
+        output_seqlens: List[int] | None = None,
         post_hook: Callable[[torch.Tensor, TensorDict], Any] | None = None,
         aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
     ) -> Any | None:
@@ -71,7 +71,7 @@ class AllocationType(enum.Enum):
 @dataclass
 class AllocationMode:
     type_: AllocationType
-    parallel_strat: None | Dict[str, Dict[str, int]]
+    parallel_strat: Dict[str, Dict[str, int]]
 
     @property
     def gen_tp_size(self) -> int:
@@ -115,7 +115,7 @@ class AllocationMode:
         raise NotImplementedError(f"Failed to parse allocation: {allocation_mode}")
 
     @staticmethod
-    def extract_3d_alloc(allocation_mode: str) -> Dict | None:
+    def extract_parallelism_strategy(allocation_mode: str) -> Dict:
         for x, y, z in itertools.permutations(["d", "t", "p"]):
             pattern = rf"{x}(\d+){y}(\d+){z}(\d+)"
             m = re.match(pattern, allocation_mode)
@@ -130,29 +130,28 @@ class AllocationMode:
                         z: c,
                     }
                 }
+        raise ValueError(
+            f"Unknown how to resolve parallelism strategy: {allocation_mode}"
+        )
 
     @staticmethod
-    def extract_decoupled_alloc(allocation_mode: str) -> Dict | None:
+    def extract_decoupled_alloc(allocation_mode: str) -> Dict:
         pattern = re.compile(
             r"(?:(?:vllm|sglang)\.(.+?)\+(.+))|(?:(.+?)\+(?:vllm|sglang)\.(.+))"
         )
         m = pattern.match(allocation_mode)
         if not m:
-            return
+            raise ValueError(
+                f"Unknown how to resolve decoupled allocation: {allocation_mode}"
+            )
         if m.group(1):
             gen_alloc = m.group(1)
             other_alloc = m.group(2)
         else:
             gen_alloc = m.group(4)
             other_alloc = m.group(3)
-        gen_alloc = AllocationMode.extract_3d_alloc(gen_alloc)
-        if not gen_alloc:
-            return
-        other_alloc = AllocationMode.extract_3d_alloc(
-            other_alloc
-        ) or AllocationMode.extract_key_value_alloc(other_alloc)
-        if not other_alloc:
-            return
+        gen_alloc = AllocationMode.extract_parallelism_strategy(gen_alloc)
+        other_alloc = AllocationMode.extract_parallelism_strategy(other_alloc)
         other_alloc.update({"gen": gen_alloc["*"]})
         return other_alloc
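A rough sketch of what the renamed parser returns, inferred from the regexes above; the module path of AllocationMode outside this hunk is an assumption:

from realhf.base.cluster import AllocationMode  # import path assumed, adjust to the actual module

# "d2t4p1" -> {"*": {"d": 2, "t": 4, "p": 1}}
strat = AllocationMode.extract_parallelism_strategy("d2t4p1")

# "sglang.d4t1p1+d2t2p1" -> trainer strategy plus {"gen": {"d": 4, "t": 1, "p": 1}}
decoupled = AllocationMode.extract_decoupled_alloc("sglang.d4t1p1+d2t2p1")

# With this change, unparseable strings raise ValueError instead of returning None.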
@@ -171,7 +170,7 @@ class SaveLoadMeta
     path: str
     weight_format: str
     with_optim: bool
-    tokenizer: PreTrainedTokenizerFast | None
+    tokenizer: Optional[PreTrainedTokenizerFast]
     base_model_path: str | None
     naive_distributed: bool = False
 
@@ -0,0 +1,128 @@
+import os
+
+import torch
+import torch.distributed as dist
+from safetensors.torch import save_file
+
+from arealite.api.cli_args import TrainEngineConfig
+from arealite.api.engine_api import FinetuneSpec, SaveLoadMeta, WeightUpdateMeta
+from arealite.engine.base_hf_engine import BaseHFEngine
+from arealite.utils.save_load import (
+    get_state_dict_from_repo_id_or_path,
+    is_existing_local_path,
+)
+from realhf.base import constants, logging
+
+logger = logging.getLogger("DeepSpeedAutoTPEngine")
+
+
+class DeepSpeedAutoTPEngine(BaseHFEngine):
+    def __init__(self, config: TrainEngineConfig):
+        super().__init__(config)
+
+    def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None):
+        """Initialize distributed communication and model."""
+        assert (
+            addr is None
+        ), "DeepSpeedAutoTPEngine does not support remote initialization."
+        import deepspeed
+
+        self.create_process_group()
+
+        world_size = int(os.environ.get("WORLD_SIZE"))
+        deepspeed.init_distributed(
+            dist_backend="nccl",
+            world_size=world_size,
+            timeout=constants.NCCL_DEFAULT_TIMEOUT,
+        )
+        self.create_device_model()
+        # NOTE: the device context manager does not work here.
+        self.model = deepspeed.tp_model_init(
+            self.model,
+            tp_size=self.config.ds_auto_tp.autotp_size,
+            dtype=getattr(torch, self.config.dtype),
+        ).to(self.device)
+        self.create_optimizer(ft_spec)
+        self.initialized = True
+
+    def _check_autotp(self):
+        tp_size = self.config.ds_auto_tp.autotp_size
+        config = self.model_config
+        num_attention_heads = config.num_attention_heads
+        num_key_value_heads = config.num_key_value_heads
+        hidden_size = config.hidden_size
+        intermediate_size = config.intermediate_size
+
+        return (
+            num_attention_heads % tp_size == 0
+            and num_key_value_heads % tp_size == 0
+            and hidden_size % tp_size == 0
+            and intermediate_size % tp_size == 0
+        )
+
+    def save(self, meta: SaveLoadMeta):
+        if meta.weight_format != "naive_distributed":
+            raise ValueError(f"Unknown weight format {meta.weight_format}. ")
+        if self.model is None:
+            raise RuntimeError("Model not initialized")
+
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        if rank == 0:
+            os.makedirs(meta.path, exist_ok=True)
+            self.model_config.save_pretrained(
+                meta.path,
+            )
+            if meta.tokenizer is not None:
+                meta.tokenizer.save_pretrained(
+                    meta.path,
+                )
+
+        state_dict = self.model.state_dict()
+        if hasattr(self.model, "module"):
+            state_dict = {
+                k.replace("module.", "", 1) if k.startswith("module.") else k: v.cpu()
+                for k, v in state_dict.items()
+            }
+        else:
+            state_dict = {k: v.cpu() for k, v in state_dict.items()}
+
+        # Only support store parameters from model partitions respectively
+        gathered_state_dicts = None
+        if rank == 0:
+            gathered_state_dicts = [None for _ in range(world_size)]
+        dist.gather_object(
+            obj=state_dict, object_gather_list=gathered_state_dicts, dst=0
+        )
+        if rank == 0:
+            for i, state_dict in enumerate(gathered_state_dicts):
+                save_file(state_dict, f"{meta.path}/rank_{i:02d}_model.safetensors")
+        if meta.with_optim:
+            self.save_optimizer_state(meta.path)
+
+    def load(self, meta: SaveLoadMeta):
+        if meta.weight_format != "naive_distributed":
+            raise ValueError(f"Unknown weight format {meta.weight_format}. ")
+        rank = dist.get_rank()
+        # Only support load full model parameters from huggingface
+        # and load model partition locally
+        if rank == 0 or is_existing_local_path(meta.path):
+            path = f"{meta.path}/rank_{rank:02d}_model.safetensors"
+            full_state = get_state_dict_from_repo_id_or_path(meta.path)
+
+            if hasattr(self.model, "module") and not hasattr(full_state):
+                full_state = {
+                    f"module.{k}" if not k.startswith("module.") else k: v
+                    for k, v in full_state.items()
+                }
+            self.model.load_state_dict(
+                full_state, strict=not self.model_config.tie_word_embeddings
+            )
+            if self.model_config.tie_word_embeddings:
+                self.model.tie_weights()
+
+        if meta.with_optim:
+            self.load_optimizer_state(meta.path)
+
+    def upload_weights(self, meta: WeightUpdateMeta):
+        raise ValueError(f"update weight not implemented {meta.type}")
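A minimal usage sketch of the new engine, assuming the launcher sets WORLD_SIZE/LOCAL_RANK and that TrainEngineConfig accepts the fields referenced above (path, dtype, ds_auto_tp); the engine module path and the checkpoint path are placeholders:

from arealite.api.cli_args import TrainEngineConfig, DeepSpeedAutoTPEngineConfig
from arealite.engine.ds_auto_tp_engine import DeepSpeedAutoTPEngine  # module path assumed

config = TrainEngineConfig(
    path="/path/to/hf/checkpoint",                     # illustrative
    ds_auto_tp=DeepSpeedAutoTPEngineConfig(autotp_size=2),
)
engine = DeepSpeedAutoTPEngine(config)
engine.initialize(addr=None, ft_spec=None)  # local (non-remote) initialization only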
@@ -0,0 +1,360 @@
+import gc
+import os
+import time
+from typing import Any, Callable, Dict, List
+
+import torch
+import torch.distributed as dist
+from tensordict import TensorDict
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    PretrainedConfig,
+    PreTrainedTokenizerFast,
+    get_constant_schedule_with_warmup,
+    get_linear_schedule_with_warmup,
+)
+
+from arealite.api.cli_args import TrainEngineConfig
+from arealite.api.engine_api import FinetuneSpec, TrainEngine
+from arealite.utils.data import (
+    MicroBatchList,
+    amend_position_ids,
+    pack_tensor_dict,
+    pad_and_stack_tensors_along_first_dim,
+    pad_mb_list,
+    reorder_list,
+    split_padded_tensor_dict_into_mb_list,
+    unpack_sequence,
+    unsqueeze_mb_list,
+)
+from arealite.utils.fsdp import get_cosine_schedule_with_warmup
+from arealite.utils.model import disable_dropout_in_model
+from realhf.api.core.data_api import load_hf_tokenizer
+from realhf.base import constants, logging
+
+logger = logging.getLogger("Base HF Engine")
+
+
+class BaseHFEngine(TrainEngine):
+    def __init__(self, config: TrainEngineConfig):
+        self.config = config
+        self.optimizer_config = config.optimizer
+
+        self.model: torch.nn.Module
+        self.optimizer: torch.optim.Optimizer
+        self.tokenizer: PreTrainedTokenizerFast
+        # huggingface model config
+        self.model_config: PretrainedConfig
+
+        # initialization
+        self.initialized = False
+        self.own_global_group = False
+        self._parallelism_group: dist.ProcessGroup
+        self.weight_update_group_initialized = False
+
+        self.world_size = int(os.environ["WORLD_SIZE"])
+
+    def train(self, mode: bool = True):
+        assert self.model is not None
+        self.model.train(mode=mode)
+        return self
+
+    @property
+    def parallelism_group(self) -> dist.ProcessGroup:
+        assert self.initialized
+        return self._parallelism_group
+
+    def create_process_group(self):
+        if not dist.is_initialized():
+            # TODO: Handle the condition when WORLD_SIZE and RANK is not set in launcher
+            dist.init_process_group(
+                backend="nccl",
+                timeout=constants.NCCL_DEFAULT_TIMEOUT,
+                device_id=torch.device(int(os.environ["LOCAL_RANK"])),
+            )
+            self.own_global_group = True
+        self._parallelism_group = dist.new_group()
+
+    def create_device_model(self):
+        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+        self.device = torch.device(int(os.environ["LOCAL_RANK"]))
+
+        dtype = getattr(torch, self.config.dtype)
+        self.model_config = AutoConfig.from_pretrained(
+            pretrained_model_name_or_path=self.config.path,
+            trust_remote_code=True,
+        )
+        self.tokenizer = load_hf_tokenizer(self.config.path)
+        tik = time.perf_counter()
+        with torch.device("cuda"):
+            if self.config.init_from_scratch:
+                # initialize scratch model from config
+                # NOTE: VLM cannot directly load state dict using this
+                # random initialized model, so otherwise we call
+                # from_pretrained rather than loading weights into this random model.
+                model = AutoModelForCausalLM.from_config(
+                    self.model_config,
+                    torch_dtype=dtype,
+                    attn_implementation=self.config.attn_impl,
+                )
+            else:
+                model = AutoModelForCausalLM.from_pretrained(
+                    pretrained_model_name_or_path=self.config.path,
+                    trust_remote_code=True,
+                    torch_dtype=dtype,
+                    attn_implementation=self.config.attn_impl,
+                )
+            if self.config.disable_dropout:
+                disable_dropout_in_model(model)
+            if self.config.gradient_checkpointing:
+                model.gradient_checkpointing_enable(
+                    gradient_checkpointing_kwargs={"use_reentrant": False}
+                )
+        logger.info(f"Model creation and loading time: {time.perf_counter() - tik}")
+        self.model = model
+
+    def create_optimizer(self, ft_spec: FinetuneSpec):
+        if self.optimizer_config is None:
+            return
+        assert self.model is not None
+        # Set up optimizer
+        tik = time.perf_counter()
+        assert (
+            self.optimizer_config.type == "adam"
+        ), "Only AdamW optimizer is supported in this engine."
+        lr = self.optimizer_config.lr
+        weight_decay = self.optimizer_config.weight_decay
+        beta1 = self.optimizer_config.beta1
+        beta2 = self.optimizer_config.beta2
+        eps = self.optimizer_config.eps
+
+        self.optimizer = torch.optim.AdamW(
+            self.model.parameters(),
+            lr=lr,
+            weight_decay=weight_decay,
+            betas=(beta1, beta2),
+            eps=eps,
+        )
+        total_train_steps = ft_spec.total_train_steps
+        num_warmup_steps = int(
+            self.optimizer_config.warmup_steps_proportion * total_train_steps
+        )
+
+        if self.optimizer_config.lr_scheduler_type == "cosine":
+            self.lr_scheduler = get_cosine_schedule_with_warmup(
+                self.optimizer,
+                num_warmup_steps,
+                total_train_steps,
+                min_lr_ratio=self.optimizer_config.min_lr_ratio,
+            )
+        elif self.optimizer_config.lr_scheduler_type == "linear":
+            self.lr_scheduler = get_linear_schedule_with_warmup(
+                self.optimizer,
+                num_warmup_steps,
+                total_train_steps,
+            )
+        elif self.optimizer_config.lr_scheduler_type == "constant":
+            self.lr_scheduler = get_constant_schedule_with_warmup(
+                self.optimizer,
+                num_warmup_steps,
+            )
+        else:
+            raise ValueError(
+                f"Unknown lr scheduler type {self.optimizer_config.lr_scheduler_type}"
+            )
+        logger.info(f"Create optimizer time: {time.perf_counter() - tik}")
+
+    def destroy(self):
+        """Destroy the engine and release GPU memory."""
+        del self.optimizer
+        del self.model
+        gc.collect()
+        torch.cuda.empty_cache()
+        gc.collect()
+        dist.destroy_process_group(self.parallelism_group)
+        if self.own_global_group:
+            dist.destroy_process_group()
+        self.initialized = False
+
+    def save_optimizer_state(self, path: str):
+        # Save FSDP sharded state dict on each rank
+        assert self.optimizer is not None
+        assert dist.is_initialized()
+        rank = dist.get_rank()
+        shard_path = os.path.join(
+            path, f"optim_world_size_{self.world_size}_rank_{rank}.pt"
+        )
+        state_dict = self.optimizer.state_dict()
+        torch.save(state_dict, shard_path)
+        dist.barrier()
+
+    def load_optimizer_state(self, path: str):
+        # Load FSDP sharded state dict
+        assert self.optimizer is not None
+        assert dist.is_initialized()
+        rank = dist.get_rank()
+        shard_path = os.path.join(
+            path, f"optim_world_size_{self.world_size}_rank_{rank}.pt"
+        )
+        optimizer_state_dict = torch.load(shard_path, weights_only=False)
+        self.optimizer.load_state_dict(optimizer_state_dict)
+        dist.barrier()
+
+    def step_lr_scheduler(self):
+        assert self.lr_scheduler is not None
+        self.lr_scheduler.step()
+
+    def prepare_mb_list(self, input_: TensorDict) -> MicroBatchList:
+        assert "attention_mask" in input_ and "input_ids" in input_
+        if isinstance(input_, dict):
+            input_ = TensorDict(input_, batch_size=[input_["input_ids"].shape[0]])
+        input_ = amend_position_ids(input_)
+        mb_list = split_padded_tensor_dict_into_mb_list(input_, self.config.mb_spec)
+        logger.info(
+            f"Microbatch #tokens (rank {dist.get_rank()}): {mb_list.group_lens}"
+        )
+        mb_list.mbs = [pack_tensor_dict(mb) for mb in mb_list.mbs]
+        mb_list = pad_mb_list(mb_list, pad_value=0.0)
+        # NOTE: We unsqueeze here because huggingface transformer models requires
+        # packed input to be of shape [1, total_seqlen].
+        mb_list = unsqueeze_mb_list(mb_list)
+        # FIXME: the resulting max_seqlen is a tensor rather than an integer
+        for mb in mb_list.mbs:
+            mb["max_seqlen"] = int(mb["max_seqlen"])
+            mb["use_cache"] = False
+        for mb in mb_list.padded_mbs:
+            mb["max_seqlen"] = int(mb["max_seqlen"])
+            mb["use_cache"] = False
+        return mb_list
+
+    def train_batch(
+        self,
+        input_: TensorDict,
+        loss_fn: Callable[[torch.Tensor, TensorDict], torch.Tensor],
+        loss_weight_fn: Callable[[TensorDict], float],
+    ) -> Dict[str, float]:
+        """Train on a batch using gradient accumulation."""
+        input_ = input_.to(self.device)
+        assert self.optimizer is not None
+        assert self.optimizer_config is not None
+        assert self.lr_scheduler is not None
+
+        self.optimizer.zero_grad()
+        mb_list = self.prepare_mb_list(input_)
+
+        total_loss_weight = torch.tensor(
+            sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32
+        )
+        assert total_loss_weight != 0
+        dist.all_reduce(total_loss_weight)
+
+        # Process microbatches with gradient accumulation
+        for i, (pad_length, padded_mb_input, mb_input) in enumerate(
+            zip(mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs)
+        ):
+            outputs = self.model(**padded_mb_input)
+
+            logits = outputs.logits.squeeze(0)
+            logits = logits[:-pad_length] if pad_length > 0 else logits
+            loss = loss_fn(logits, mb_input)
+            loss_scale = loss_weight_fn(mb_input) / total_loss_weight
+
+            # Scale loss for accumulation
+            # Revert gradient averaging across dp ranks
+            loss_scale *= self.world_size
+
+            loss *= loss_scale
+            loss.backward()
+
+        grad_norm = torch.nn.utils.clip_grad_norm_(
+            self.model.parameters(),
+            self.optimizer_config.gradient_clipping,
+            norm_type=2.0,
+            error_if_nonfinite=False,
+            foreach=None,
+        )
+        if not torch.isfinite(grad_norm):
+            self.optimizer.zero_grad()
+            update_successful = False
+        else:
+            self.optimizer.step()
+            update_successful = True
+
+        current_lr = self.lr_scheduler.get_last_lr()[0]
+        # Optimizer step
+        self.optimizer.step()
+        return dict(
+            update_successful=float(update_successful),
+            grad_norm=float(grad_norm) if grad_norm is not None else float("nan"),
+            lr=current_lr,
+        )
+
+    @torch.no_grad()
+    def eval_batch(
+        self,
+        input_: TensorDict,
+        loss_fn: Callable[[torch.Tensor, TensorDict], torch.Tensor],
+        loss_weight_fn: Callable[[TensorDict], float],
+    ) -> torch.Tensor | None:
+        """Evaluate on a batch."""
+        input_ = input_.to(self.device)
+        mb_list = self.prepare_mb_list(input_)
+        total_loss_weight = torch.tensor(
+            sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32
+        )
+        assert total_loss_weight != 0
+
+        total_loss = 0.0
+        total_weight = 0.0
+
+        for pad_length, padded_mb_input, mb_input in zip(
+            mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
+        ):
+            outputs = self.model(**padded_mb_input)
+            logits = outputs.logits.squeeze(0)
+            logits = logits[:-pad_length] if pad_length > 0 else logits
+            loss = loss_fn(logits, mb_input)
+
+            # Simple weight calculation (could be improved)
+            loss_scale = loss_weight_fn(mb_input) / total_loss_weight
+            total_loss += loss.item() * loss_scale
+            total_weight += loss_scale
+
+        return torch.tensor(total_loss / total_weight)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_: TensorDict,
+        output_seqlens: List[int] | None = None,
+        post_hook: Callable[[torch.Tensor, TensorDict], Any] | None = None,
+        aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
+    ) -> Any | None:
+        """Forward pass with optional post-processing."""
+        input_ = input_.to(self.device)
+        cu_seqlens = pack_tensor_dict(input_)["cu_seqlens"]
+        mb_list = self.prepare_mb_list(input_)
+
+        if output_seqlens is None:
+            output_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist()
+
+        results = []
+        for pad_length, padded_mb_input, mb_input in zip(
+            mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
+        ):
+            outputs = self.model(**padded_mb_input)
+            logits = outputs.logits.squeeze(0)
+            logits = logits[:-pad_length] if pad_length > 0 else logits
+
+            if post_hook:
+                result = post_hook(logits, mb_input)
+                results.append(result)
+            else:
+                results.append(logits)
+
+        res = aggregate_fn(results)
+        output_seqlens = [output_seqlens[i] for i in mb_list.forward_indices]
+        unpacked = unpack_sequence(res, lens=output_seqlens, dim=0)
+        reordered = reorder_list(unpacked, mb_list.backward_indices)
+        return pad_and_stack_tensors_along_first_dim(reordered)
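For reference, a hedged sketch of the loss_fn / loss_weight_fn pair that train_batch and eval_batch above expect; the TensorDict keys used here (packed input_ids plus an assumed loss_mask) are illustrative and not part of this diff:

import torch
from tensordict import TensorDict

def loss_weight_fn(mb: TensorDict) -> float:
    # Weight each micro-batch by its number of trainable tokens (assumed "loss_mask" key).
    return float(mb["loss_mask"].sum())

def loss_fn(logits: torch.Tensor, mb: TensorDict) -> torch.Tensor:
    # Token-level cross-entropy over the packed sequence; logits have shape
    # [total_seqlen, vocab] after the squeeze(0) performed in train_batch.
    labels = mb["input_ids"].squeeze(0)
    per_token = torch.nn.functional.cross_entropy(
        logits[:-1], labels[1:], reduction="none"
    )
    mask = mb["loss_mask"].squeeze(0)[1:].float()
    return (per_token * mask).sum() / mask.sum().clamp(min=1.0)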
@@ -1,42 +1,20 @@
-import gc
 import os
 import time
 from datetime import datetime
-from typing import Any, Callable, Dict, List, Optional
+from typing import Callable, Dict, Optional
 
 import torch
 import torch.distributed as dist
-import transformers
 from tensordict import TensorDict
 from torch.distributed.checkpoint.state_dict import (
     StateDictOptions,
     get_model_state_dict,
 )
-from transformers import (
-    AutoConfig,
-    AutoModelForCausalLM,
-    get_constant_schedule_with_warmup,
-    get_linear_schedule_with_warmup,
-)
+from transformers import PreTrainedTokenizerFast
 
 from arealite.api.cli_args import TrainEngineConfig
-from arealite.api.engine_api import (
-    FinetuneSpec,
-    SaveLoadMeta,
-    TrainEngine,
-    WeightUpdateMeta,
-)
-from arealite.utils.data import (
-    MicroBatchList,
-    amend_position_ids,
-    pack_tensor_dict,
-    pad_and_stack_tensors_along_first_dim,
-    pad_mb_list,
-    reorder_list,
-    split_padded_tensor_dict_into_mb_list,
-    unpack_sequence,
-    unsqueeze_mb_list,
-)
+from arealite.api.engine_api import FinetuneSpec, SaveLoadMeta, WeightUpdateMeta
+from arealite.engine.base_hf_engine import BaseHFEngine
 from arealite.utils.fsdp import (
     CPUOffloadPolicy,
     MixedPrecisionPolicy,
@ -44,108 +22,35 @@ from arealite.utils.fsdp import (
|
||||||
create_fsdp_device_mesh,
|
create_fsdp_device_mesh,
|
||||||
fsdp2_clip_grad_norm_,
|
fsdp2_clip_grad_norm_,
|
||||||
fsdp2_load_full_state_dict,
|
fsdp2_load_full_state_dict,
|
||||||
get_cosine_schedule_with_warmup,
|
|
||||||
)
|
)
|
||||||
from arealite.utils.model import disable_dropout_in_model
|
|
||||||
from arealite.utils.save_load import get_state_dict_from_repo_id_or_path
|
from arealite.utils.save_load import get_state_dict_from_repo_id_or_path
|
||||||
from realhf.api.core.data_api import load_hf_tokenizer
|
from realhf.base import logging, name_resolve, names, pkg_version
|
||||||
from realhf.base import constants, logging, name_resolve, names, pkg_version
|
|
||||||
|
|
||||||
logger = logging.getLogger("FSDPEngine")
|
logger = logging.getLogger("FSDPEngine")
|
||||||
|
|
||||||
|
|
||||||
class FSDPEngine(TrainEngine):
|
class FSDPEngine(BaseHFEngine):
|
||||||
def __init__(self, config: TrainEngineConfig):
|
def __init__(self, config: TrainEngineConfig):
|
||||||
self.config = config
|
super().__init__(config)
|
||||||
self.optimizer_config = config.optimizer
|
|
||||||
|
|
||||||
self.model = None
|
|
||||||
self.optimizer = None
|
|
||||||
self.tokenizer = None
|
|
||||||
# huggingface model config
|
|
||||||
self.model_config = None
|
|
||||||
# FSDP options
|
# FSDP options
|
||||||
self.mixed_precision_policy = None
|
self.mixed_precision_policy = None
|
||||||
self.device_mesh = None
|
self.device_mesh = None
|
||||||
self.cpu_offload = None
|
self.cpu_offload = None
|
||||||
# initialization
|
|
||||||
self.initialized = False
|
|
||||||
self.own_global_group = False
|
|
||||||
self._parallelism_group = None
|
|
||||||
self.weight_update_group_initialized = False
|
|
||||||
|
|
||||||
# TODO: Handle the case when WORLD_SIZE is not set in launcher
|
|
||||||
self.world_size = int(os.environ["WORLD_SIZE"])
|
|
||||||
|
|
||||||
def train(self, mode: bool = True):
|
|
||||||
assert self.model is not None
|
|
||||||
self.model.train(mode=mode)
|
|
||||||
return self
|
|
||||||
|
|
||||||
@property
|
|
||||||
def parallelism_group(self) -> dist.ProcessGroup:
|
|
||||||
assert self.initialized
|
|
||||||
return self._parallelism_group
|
|
||||||
|
|
||||||
def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None):
|
def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None):
|
||||||
# Initialize distributed enviroments and load model.
|
# Initialize distributed enviroments and load model.
|
||||||
assert addr is None, "FSDPEngine does not support remote initialization."
|
assert addr is None, "FSDPEngine does not support remote initialization."
|
||||||
|
|
||||||
assert pkg_version.is_version_greater_or_equal(
|
assert pkg_version.is_version_greater_or_equal(
|
||||||
"torch", "2.4.0"
|
"torch", "2.4.0"
|
||||||
), f"arealite only supports FSDP2, which requires torch>=2.4.0"
|
), f"arealite only supports FSDP2, which requires torch>=2.4.0"
|
||||||
|
|
||||||
"""Initialize distributed communication and model."""
|
self.create_process_group()
|
||||||
if not dist.is_initialized():
|
self.create_device_model()
|
||||||
# TODO: Handle the condition when WORLD_SIZE and RANK is not set in launcher
|
|
||||||
dist.init_process_group(
|
|
||||||
backend="nccl",
|
|
||||||
timeout=constants.NCCL_DEFAULT_TIMEOUT,
|
|
||||||
device_id=torch.device(int(os.environ["LOCAL_RANK"])),
|
|
||||||
)
|
|
||||||
self.own_global_group = True
|
|
||||||
self._parallelism_group = dist.new_group()
|
|
||||||
|
|
||||||
# TODO: Handle the condition when LOCAL_RANK is not set in launcher
|
|
||||||
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
|
|
||||||
self.device = torch.device(int(os.environ["LOCAL_RANK"]))
|
|
||||||
|
|
||||||
dtype = getattr(torch, self.config.dtype)
|
|
||||||
self.model_config = AutoConfig.from_pretrained(
|
|
||||||
pretrained_model_name_or_path=self.config.path,
|
|
||||||
trust_remote_code=True,
|
|
||||||
)
|
|
||||||
self.tokenizer = load_hf_tokenizer(self.config.path)
|
|
||||||
tik = time.perf_counter()
|
|
||||||
with torch.device("cuda"):
|
|
||||||
if self.config.init_from_scratch:
|
|
||||||
# initialize scratch model from config
|
|
||||||
# NOTE: VLM cannot directly load state dict using this
|
|
||||||
# random initialized model, so otherwise we call
|
|
||||||
# from_pretrained rather than loading weights into this random model.
|
|
||||||
model = AutoModelForCausalLM.from_config(
|
|
||||||
self.model_config,
|
|
||||||
torch_dtype=dtype,
|
|
||||||
attn_implementation=self.config.attn_impl,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
|
||||||
pretrained_model_name_or_path=self.config.path,
|
|
||||||
trust_remote_code=True,
|
|
||||||
torch_dtype=dtype,
|
|
||||||
attn_implementation=self.config.attn_impl,
|
|
||||||
)
|
|
||||||
if self.config.disable_dropout:
|
|
||||||
disable_dropout_in_model(model)
|
|
||||||
if self.config.gradient_checkpointing:
|
|
||||||
model.gradient_checkpointing_enable(
|
|
||||||
gradient_checkpointing_kwargs={"use_reentrant": False}
|
|
||||||
)
|
|
||||||
logger.info(f"Model creation and loading time: {time.perf_counter() - tik}")
|
|
||||||
|
|
||||||
|
# Wrap with FSDP2
|
||||||
# Simple auto wrap policy
|
# Simple auto wrap policy
|
||||||
self.mixed_precision_policy = MixedPrecisionPolicy(
|
self.mixed_precision_policy = MixedPrecisionPolicy(
|
||||||
param_dtype=dtype,
|
param_dtype=getattr(torch, self.config.dtype),
|
||||||
reduce_dtype=torch.float32,
|
reduce_dtype=torch.float32,
|
||||||
cast_forward_inputs=True,
|
cast_forward_inputs=True,
|
||||||
)
|
)
|
||||||
|
@ -154,82 +59,19 @@ class FSDPEngine(TrainEngine):
|
||||||
self.cpu_offload = (
|
self.cpu_offload = (
|
||||||
CPUOffloadPolicy() if self.config.fsdp.offload_params else None
|
CPUOffloadPolicy() if self.config.fsdp.offload_params else None
|
||||||
)
|
)
|
||||||
|
|
||||||
fsdp_kwargs = {
|
fsdp_kwargs = {
|
||||||
"mesh": self.device_mesh,
|
"mesh": self.device_mesh,
|
||||||
"mp_policy": self.mixed_precision_policy,
|
"mp_policy": self.mixed_precision_policy,
|
||||||
"offload_policy": self.cpu_offload,
|
"offload_policy": self.cpu_offload,
|
||||||
"reshard_after_forward": True,
|
"reshard_after_forward": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Wrap with FSDP2
|
|
||||||
tik = time.perf_counter()
|
tik = time.perf_counter()
|
||||||
apply_fsdp2(model, fsdp_kwargs, self.config.fsdp.wrap_policy)
|
apply_fsdp2(self.model, fsdp_kwargs, self.config.fsdp.wrap_policy)
|
||||||
logger.info(f"Applying FSDP2 time: {time.perf_counter() - tik}")
|
logger.info(f"Applying FSDP2 time: {time.perf_counter() - tik}")
|
||||||
self.model = model
|
|
||||||
|
|
||||||
# Set up optimizer
|
|
||||||
if self.optimizer_config is not None:
|
|
||||||
tik = time.perf_counter()
|
|
||||||
assert (
|
|
||||||
self.optimizer_config.type == "adam"
|
|
||||||
), "Only AdamW optimizer is supported in this engine."
|
|
||||||
lr = self.optimizer_config.lr
|
|
||||||
weight_decay = self.optimizer_config.weight_decay
|
|
||||||
beta1 = self.optimizer_config.beta1
|
|
||||||
beta2 = self.optimizer_config.beta2
|
|
||||||
eps = self.optimizer_config.eps
|
|
||||||
|
|
||||||
self.optimizer = torch.optim.AdamW(
|
|
||||||
self.model.parameters(),
|
|
||||||
lr=lr,
|
|
||||||
weight_decay=weight_decay,
|
|
||||||
betas=(beta1, beta2),
|
|
||||||
eps=eps,
|
|
||||||
)
|
|
||||||
total_train_steps = ft_spec.total_train_steps
|
|
||||||
num_warmup_steps = int(
|
|
||||||
self.optimizer_config.warmup_steps_proportion * total_train_steps
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.optimizer_config.lr_scheduler_type == "cosine":
|
|
||||||
self.lr_scheduler = get_cosine_schedule_with_warmup(
|
|
||||||
self.optimizer,
|
|
||||||
num_warmup_steps,
|
|
||||||
total_train_steps,
|
|
||||||
min_lr_ratio=self.optimizer_config.min_lr_ratio,
|
|
||||||
)
|
|
||||||
elif self.optimizer_config.lr_scheduler_type == "linear":
|
|
||||||
self.lr_scheduler = get_linear_schedule_with_warmup(
|
|
||||||
self.optimizer,
|
|
||||||
num_warmup_steps,
|
|
||||||
total_train_steps,
|
|
||||||
)
|
|
||||||
elif self.optimizer_config.lr_scheduler_type == "constant":
|
|
||||||
self.lr_scheduler = get_constant_schedule_with_warmup(
|
|
||||||
self.optimizer,
|
|
||||||
num_warmup_steps,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"Unknown lr scheduler type {self.optimizer_config.lr_scheduler_type}"
|
|
||||||
)
|
|
||||||
logger.info(f"Create optimizer time: {time.perf_counter() - tik}")
|
|
||||||
|
|
||||||
|
self.create_optimizer(ft_spec)
|
||||||
self.initialized = True
|
self.initialized = True
|
||||||
|
|
||||||
def destroy(self):
|
|
||||||
"""Destroy the engine and release GPU memory."""
|
|
||||||
self.model = None
|
|
||||||
self.optimizer = None
|
|
||||||
gc.collect()
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
gc.collect()
|
|
||||||
dist.destroy_process_group(self.parallelism_group)
|
|
||||||
if self.own_global_group:
|
|
||||||
dist.destroy_process_group()
|
|
||||||
self.initialized = False
|
|
||||||
|
|
||||||
def save(self, meta: SaveLoadMeta):
|
def save(self, meta: SaveLoadMeta):
|
||||||
if meta.weight_format == "hf":
|
if meta.weight_format == "hf":
|
||||||
self._save_model_to_hf(meta.path, meta.tokenizer)
|
self._save_model_to_hf(meta.path, meta.tokenizer)
|
||||||
|
@@ -240,7 +82,7 @@ class FSDPEngine(TrainEngine):
             raise ValueError(f"Unknown weight format {meta.weight_format}. ")
 
         if meta.with_optim:
-            self._save_optimizer_state(meta.path)
+            self.save_optimizer_state(meta.path)
 
     def load(self, meta: SaveLoadMeta):
         if meta.weight_format == "hf":
@ -252,34 +94,10 @@ class FSDPEngine(TrainEngine):
|
||||||
raise ValueError(f"Unknown weight format {meta.weight_format}. ")
|
raise ValueError(f"Unknown weight format {meta.weight_format}. ")
|
||||||
|
|
||||||
if meta.with_optim:
|
if meta.with_optim:
|
||||||
self._load_optimizer_state(meta.path)
|
self.load_optimizer_state(meta.path)
|
||||||
|
|
||||||
def _save_optimizer_state(self, path: str):
|
|
||||||
# Save FSDP sharded state dict on each rank
|
|
||||||
assert self.optimizer is not None
|
|
||||||
assert dist.is_initialized()
|
|
||||||
rank = dist.get_rank()
|
|
||||||
shard_path = os.path.join(
|
|
||||||
path, f"optim_world_size_{self.world_size}_rank_{rank}.pt"
|
|
||||||
)
|
|
||||||
state_dict = self.optimizer.state_dict()
|
|
||||||
torch.save(state_dict, shard_path)
|
|
||||||
dist.barrier()
|
|
||||||
|
|
||||||
def _load_optimizer_state(self, path: str):
|
|
||||||
# Load FSDP sharded state dict
|
|
||||||
assert self.optimizer is not None
|
|
||||||
assert dist.is_initialized()
|
|
||||||
rank = dist.get_rank()
|
|
||||||
shard_path = os.path.join(
|
|
||||||
path, f"optim_world_size_{self.world_size}_rank_{rank}.pt"
|
|
||||||
)
|
|
||||||
optimizer_state_dict = torch.load(shard_path, weights_only=False)
|
|
||||||
self.optimizer.load_state_dict(optimizer_state_dict)
|
|
||||||
dist.barrier()
|
|
||||||
|
|
||||||
def _save_model_to_hf(
|
def _save_model_to_hf(
|
||||||
self, path: str, tokenizer: Optional[transformers.PreTrainedTokenizerFast]
|
self, path: str, tokenizer: Optional[PreTrainedTokenizerFast]
|
||||||
):
|
):
|
||||||
"""Save model in HuggingFace format."""
|
"""Save model in HuggingFace format."""
|
||||||
if self.model is None:
|
if self.model is None:
|
||||||
|
@ -345,35 +163,11 @@ class FSDPEngine(TrainEngine):
|
||||||
"Distributed weight update is not implemented for FSDPEngine yet. "
|
"Distributed weight update is not implemented for FSDPEngine yet. "
|
||||||
)
|
)
|
||||||
|
|
||||||
def step_lr_scheduler(self):
|
|
||||||
assert self.lr_scheduler is not None
|
|
||||||
self.lr_scheduler.step()
|
|
||||||
|
|
||||||
def _prepare_mb_list(self, input_: TensorDict) -> MicroBatchList:
|
|
||||||
assert "attention_mask" in input_ and "input_ids" in input_
|
|
||||||
if isinstance(input_, dict):
|
|
||||||
input_ = TensorDict(input_, batch_size=[input_["input_ids"].shape[0]])
|
|
||||||
input_ = amend_position_ids(input_)
|
|
||||||
mb_list = split_padded_tensor_dict_into_mb_list(input_, self.config.mb_spec)
|
|
||||||
logger.info(
|
|
||||||
f"Microbatch #tokens (rank {dist.get_rank()}): {mb_list.group_lens}"
|
|
||||||
)
|
|
||||||
mb_list.mbs = [pack_tensor_dict(mb) for mb in mb_list.mbs]
|
|
||||||
mb_list = pad_mb_list(mb_list, pad_value=0.0)
|
|
||||||
# NOTE: We unsqueeze here because huggingface transformer models requires
|
|
||||||
# packed input to be of shape [1, total_seqlen].
|
|
||||||
mb_list = unsqueeze_mb_list(mb_list)
|
|
||||||
# FIXME: the resulting max_seqlen is a tensor rather than an integer
|
|
||||||
for mb in mb_list.mbs:
|
|
||||||
mb["max_seqlen"] = int(mb["max_seqlen"])
|
|
||||||
mb["use_cache"] = False
|
|
||||||
return mb_list
|
|
||||||
|
|
||||||
def train_batch(
|
def train_batch(
|
||||||
self,
|
self,
|
||||||
input_: TensorDict,
|
input_: TensorDict,
|
||||||
loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
|
loss_fn: Callable[[torch.Tensor, TensorDict], torch.Tensor],
|
||||||
loss_weight_fn: Callable[[Dict], float],
|
loss_weight_fn: Callable[[TensorDict], float],
|
||||||
) -> Dict[str, float]:
|
) -> Dict[str, float]:
|
||||||
"""Train on a batch using gradient accumulation."""
|
"""Train on a batch using gradient accumulation."""
|
||||||
input_ = input_.to(self.device)
|
input_ = input_.to(self.device)
|
||||||
|
@@ -382,7 +176,7 @@ class FSDPEngine(TrainEngine):
         assert self.lr_scheduler is not None
 
         self.optimizer.zero_grad()
-        mb_list = self._prepare_mb_list(input_)
+        mb_list = self.prepare_mb_list(input_)
 
         total_loss_weight = torch.tensor(
             sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32
@@ -394,7 +188,6 @@ class FSDPEngine(TrainEngine):
         for i, (pad_length, padded_mb_input, mb_input) in enumerate(
             zip(mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs)
         ):
-            self.model.set_is_last_backward(i == len(mb_list.mbs) - 1)
             outputs = self.model(**padded_mb_input)
 
             logits = outputs.logits.squeeze(0)
@@ -409,6 +202,7 @@ class FSDPEngine(TrainEngine):
             loss *= loss_scale
             loss.backward()
 
+        # NOTE: grad norm clip function is different
         grad_norm = fsdp2_clip_grad_norm_(
             self.model.parameters(), max_norm=self.optimizer_config.gradient_clipping
         )
@ -427,72 +221,3 @@ class FSDPEngine(TrainEngine):
|
||||||
grad_norm=float(grad_norm) if grad_norm is not None else float("nan"),
|
grad_norm=float(grad_norm) if grad_norm is not None else float("nan"),
|
||||||
lr=current_lr,
|
lr=current_lr,
|
||||||
)
|
)
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def eval_batch(
|
|
||||||
self,
|
|
||||||
input_: TensorDict,
|
|
||||||
loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
|
|
||||||
loss_weight_fn: Callable[[Dict], float],
|
|
||||||
) -> torch.Tensor | None:
|
|
||||||
"""Evaluate on a batch."""
|
|
||||||
input_ = input_.to(self.device)
|
|
||||||
mb_list = self._prepare_mb_list(input_)
|
|
||||||
total_loss_weight = torch.tensor(
|
|
||||||
sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32
|
|
||||||
)
|
|
||||||
assert total_loss_weight != 0
|
|
||||||
|
|
||||||
total_loss = 0.0
|
|
||||||
total_weight = 0.0
|
|
||||||
|
|
||||||
for pad_length, padded_mb_input, mb_input in zip(
|
|
||||||
mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
|
|
||||||
):
|
|
||||||
outputs = self.model(**padded_mb_input)
|
|
||||||
logits = outputs.logits.squeeze(0)
|
|
||||||
logits = logits[:-pad_length] if pad_length > 0 else logits
|
|
||||||
loss = loss_fn(logits, mb_input)
|
|
||||||
|
|
||||||
# Simple weight calculation (could be improved)
|
|
||||||
loss_scale = loss_weight_fn(mb_input) / total_loss_weight
|
|
||||||
total_loss += loss.item() * loss_scale
|
|
||||||
total_weight += loss_scale
|
|
||||||
|
|
||||||
return torch.tensor(total_loss / total_weight)
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
input_: TensorDict,
|
|
||||||
output_seqlens: List[int] | None = None,
|
|
||||||
post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
|
|
||||||
aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
|
|
||||||
) -> Any | None:
|
|
||||||
"""Forward pass with optional post-processing."""
|
|
||||||
input_ = input_.to(self.device)
|
|
||||||
cu_seqlens = pack_tensor_dict(input_)["cu_seqlens"]
|
|
||||||
mb_list = self._prepare_mb_list(input_)
|
|
||||||
|
|
||||||
if output_seqlens is None:
|
|
||||||
output_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist()
|
|
||||||
|
|
||||||
results = []
|
|
||||||
for pad_length, padded_mb_input, mb_input in zip(
|
|
||||||
mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
|
|
||||||
):
|
|
||||||
outputs = self.model(**padded_mb_input)
|
|
||||||
logits = outputs.logits.squeeze(0)
|
|
||||||
logits = logits[:-pad_length] if pad_length > 0 else logits
|
|
||||||
|
|
||||||
if post_hook:
|
|
||||||
result = post_hook(logits, mb_input)
|
|
||||||
results.append(result)
|
|
||||||
else:
|
|
||||||
results.append(logits)
|
|
||||||
|
|
||||||
res = aggregate_fn(results)
|
|
||||||
output_seqlens = [output_seqlens[i] for i in mb_list.forward_indices]
|
|
||||||
unpacked = unpack_sequence(res, lens=output_seqlens, dim=0)
|
|
||||||
reordered = reorder_list(unpacked, mb_list.backward_indices)
|
|
||||||
return pad_and_stack_tensors_along_first_dim(reordered)
|
|
||||||
|
|
|
@ -1,467 +0,0 @@
|
||||||
import gc
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
from typing import Any, Callable, Dict, List, Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
import transformers
|
|
||||||
from safetensors.torch import save_file
|
|
||||||
from tensordict import TensorDict
|
|
||||||
from transformers import (
|
|
||||||
AutoConfig,
|
|
||||||
AutoModelForCausalLM,
|
|
||||||
get_constant_schedule_with_warmup,
|
|
||||||
get_linear_schedule_with_warmup,
|
|
||||||
)
|
|
||||||
|
|
||||||
from arealite.api.cli_args import TrainEngineConfig
|
|
||||||
from arealite.api.engine_api import (
|
|
||||||
FinetuneSpec,
|
|
||||||
SaveLoadMeta,
|
|
||||||
TrainEngine,
|
|
||||||
WeightUpdateMeta,
|
|
||||||
)
|
|
||||||
from arealite.utils.data import (
|
|
||||||
MicroBatchList,
|
|
||||||
amend_position_ids,
|
|
||||||
pack_tensor_dict,
|
|
||||||
pad_and_stack_tensors_along_first_dim,
|
|
||||||
pad_mb_list,
|
|
||||||
reorder_list,
|
|
||||||
split_packed_tensor_dict_into_mb_list,
|
|
||||||
unpack_sequence,
|
|
||||||
unsqueeze_mb_list,
|
|
||||||
)
|
|
||||||
from arealite.utils.fsdp import get_cosine_schedule_with_warmup
|
|
||||||
from arealite.utils.save_load import (
|
|
||||||
get_state_dict_from_repo_id_or_path,
|
|
||||||
is_existing_local_path,
|
|
||||||
)
|
|
||||||
from realhf.api.core.data_api import load_hf_tokenizer
|
|
||||||
from realhf.base import logging, name_resolve, names
|
|
||||||
|
|
||||||
logger = logging.getLogger("HFEngine")
|
|
||||||
|
|
||||||
|
|
||||||
class HFEngine(TrainEngine):
|
|
||||||
def __init__(self, config: TrainEngineConfig):
|
|
||||||
self.config = config
|
|
||||||
self.optimizer_config = config.optimizer
|
|
||||||
|
|
||||||
self.model = None
|
|
||||||
self.optimizer = None
|
|
||||||
self.tokenizer = None
|
|
||||||
# huggingface model config
|
|
||||||
self.model_config = None
|
|
||||||
# initialization
|
|
||||||
self.initialized = False
|
|
||||||
self.weight_update_group_initialized = False
|
|
||||||
|
|
||||||
def train(self, mode: bool = True):
|
|
||||||
assert self.model is not None
|
|
||||||
self.model.train(mode=mode)
|
|
||||||
return self
|
|
||||||
|
|
||||||
def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None):
|
|
||||||
"""Initialize distributed communication and model."""
|
|
||||||
assert addr is None, "HFEngine does not support remote initialization."
|
|
||||||
|
|
||||||
world_size = int(os.environ.get("WORLD_SIZE", 0))
|
|
||||||
if not dist.is_initialized() and world_size > 1:
|
|
||||||
try:
|
|
||||||
import deepspeed
|
|
||||||
except ImportError:
|
|
||||||
print(
|
|
||||||
"Warning: deepspeed is not installed. Some functionality may be disabled."
|
|
||||||
)
|
|
||||||
deepspeed.init_distributed(dist_backend="nccl", world_size=world_size)
|
|
||||||
|
|
||||||
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
|
||||||
torch.cuda.set_device(local_rank)
|
|
||||||
self.device = torch.device(f"cuda:{local_rank}")
|
|
||||||
|
|
||||||
dtype = getattr(torch, self.config.dtype)
|
|
||||||
self.model_config = AutoConfig.from_pretrained(
|
|
||||||
pretrained_model_name_or_path=self.config.path,
|
|
||||||
trust_remote_code=True,
|
|
||||||
)
|
|
||||||
self.tokenizer = load_hf_tokenizer(self.config.path)
|
|
||||||
|
|
||||||
self.model = AutoModelForCausalLM.from_config(
|
|
||||||
self.model_config,
|
|
||||||
torch_dtype=dtype,
|
|
||||||
attn_implementation=self.config.attn_impl,
|
|
||||||
).to(f"cuda:{local_rank}")
|
|
||||||
|
|
||||||
if not self.config.init_from_scratch:
|
|
||||||
# Load model from a initial checkpoint path,
|
|
||||||
# which should only be a huggingface checkpoint.
|
|
||||||
load_meta = SaveLoadMeta(
|
|
||||||
path=self.config.path,
|
|
||||||
weight_format="hf",
|
|
||||||
with_optim=False,
|
|
||||||
tokenizer=None,
|
|
||||||
base_model_path=self.config.path,
|
|
||||||
naive_distributed=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.load(load_meta)
|
|
||||||
|
|
||||||
if world_size > 1:
|
|
||||||
if self._check_autotp():
|
|
||||||
self.model = deepspeed.tp_model_init(
|
|
||||||
self.model, tp_size=self.config.hf.autotp_size, dtype=dtype
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise RuntimeError("DeepSpeed AutoTP configuration error in HFEngine. ")
|
|
||||||
|
|
||||||
# Set up optimizer
|
|
||||||
if self.optimizer_config is not None:
|
|
||||||
assert (
|
|
||||||
self.optimizer_config.type == "adam"
|
|
||||||
), "Only AdamW optimizer is supported in this engine."
|
|
||||||
lr = self.optimizer_config.lr
|
|
||||||
weight_decay = self.optimizer_config.weight_decay
|
|
||||||
beta1 = self.optimizer_config.beta1
|
|
||||||
beta2 = self.optimizer_config.beta2
|
|
||||||
eps = self.optimizer_config.eps
|
|
||||||
|
|
||||||
self.optimizer = torch.optim.AdamW(
|
|
||||||
self.model.parameters(),
|
|
||||||
lr=lr,
|
|
||||||
weight_decay=weight_decay,
|
|
||||||
betas=(beta1, beta2),
|
|
||||||
eps=eps,
|
|
||||||
)
|
|
||||||
total_train_steps = ft_spec.total_train_steps
|
|
||||||
num_warmup_steps = int(
|
|
||||||
self.optimizer_config.warmup_steps_proportion * total_train_steps
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.optimizer_config.lr_scheduler_type == "cosine":
|
|
||||||
self.lr_scheduler = get_cosine_schedule_with_warmup(
|
|
||||||
self.optimizer,
|
|
||||||
num_warmup_steps,
|
|
||||||
total_train_steps,
|
|
||||||
min_lr_ratio=self.optimizer_config.min_lr_ratio,
|
|
||||||
)
|
|
||||||
elif self.optimizer_config.lr_scheduler_type == "linear":
|
|
||||||
self.lr_scheduler = get_linear_schedule_with_warmup(
|
|
||||||
self.optimizer,
|
|
||||||
num_warmup_steps,
|
|
||||||
total_train_steps,
|
|
||||||
)
|
|
||||||
elif self.optimizer_config.lr_scheduler_type == "constant":
|
|
||||||
self.lr_scheduler = get_constant_schedule_with_warmup(
|
|
||||||
self.optimizer,
|
|
||||||
num_warmup_steps,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"Unknown lr scheduler type {self.optimizer_config.lr_scheduler_type}"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.initialized = True
|
|
||||||
|
|
||||||
def _check_autotp(self):
|
|
||||||
tp_size = self.config.hf.autotp_size
|
|
||||||
config = self.model_config
|
|
||||||
num_attention_heads = config.num_attention_heads
|
|
||||||
num_key_value_heads = config.num_key_value_heads
|
|
||||||
hidden_size = config.hidden_size
|
|
||||||
intermediate_size = config.intermediate_size
|
|
||||||
|
|
||||||
return (
|
|
||||||
num_attention_heads % tp_size == 0
|
|
||||||
and num_key_value_heads % tp_size == 0
|
|
||||||
and hidden_size % tp_size == 0
|
|
||||||
and intermediate_size % tp_size == 0
|
|
||||||
)
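
DeepSpeed AutoTP can only shard a model whose attention heads, KV heads, hidden size, and intermediate size are all divisible by the tensor-parallel degree, which is exactly what the check above verifies. A minimal sketch with made-up dimensions (not tied to any particular checkpoint):

def divisible_by_tp(tp_size: int) -> bool:
    # Hypothetical model dimensions, for illustration only.
    dims = dict(
        num_attention_heads=32,
        num_key_value_heads=8,
        hidden_size=4096,
        intermediate_size=11008,
    )
    return all(v % tp_size == 0 for v in dims.values())

assert divisible_by_tp(2)        # every dimension divides evenly
assert not divisible_by_tp(3)    # 32 attention heads do not split across 3 ranks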
|
|
||||||
|
|
||||||
def destroy(self):
|
|
||||||
"""Destroy the engine and release GPU memory."""
|
|
||||||
self.model = None
|
|
||||||
self.optimizer = None
|
|
||||||
gc.collect()
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
gc.collect()
|
|
||||||
self.initialized = False
|
|
||||||
|
|
||||||
def save(self, meta: SaveLoadMeta):
|
|
||||||
if meta.weight_format == "hf":
|
|
||||||
self._save_model_to_hf(meta.path, meta.tokenizer, meta.naive_distributed)
|
|
||||||
elif meta.weight_format == "dcp":
|
|
||||||
# TODO: implement DCP save/load for HF
|
|
||||||
raise NotImplementedError("DCP format saving is not implemented yet. ")
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown weight format {meta.weight_format}. ")
|
|
||||||
|
|
||||||
if meta.with_optim:
|
|
||||||
self._save_optimizer_state(meta.path)
|
|
||||||
|
|
||||||
def load(self, meta: SaveLoadMeta):
|
|
||||||
if meta.weight_format == "hf":
|
|
||||||
self._load_model_from_hf(meta.path, meta.naive_distributed)
|
|
||||||
elif meta.weight_format == "dcp":
|
|
||||||
# TODO: implement DCP save/load for HF
|
|
||||||
raise NotImplementedError("DCP format loading is not implemented yet. ")
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown weight format {meta.weight_format}. ")
|
|
||||||
|
|
||||||
if meta.with_optim:
|
|
||||||
self._load_optimizer_state(meta.path)
|
|
||||||
|
|
||||||
def _save_optimizer_state(self, path: str):
|
|
||||||
assert self.optimizer is not None
|
|
||||||
os.makedirs(path, exist_ok=True)
|
|
||||||
torch.save(self.optimizer.state_dict(), os.path.join(path, "optim.pt"))
|
|
||||||
|
|
||||||
def _load_optimizer_state(self, path: str):
|
|
||||||
assert self.optimizer is not None
|
|
||||||
path = os.path.join(path, "optim.pt")
|
|
||||||
optimizer_state_dict = torch.load(path, weights_only=False)
|
|
||||||
self.optimizer.load_state_dict(optimizer_state_dict)
|
|
||||||
|
|
||||||
def _save_model_to_hf(
|
|
||||||
self,
|
|
||||||
path: str,
|
|
||||||
tokenizer: Optional[transformers.PreTrainedTokenizerFast],
|
|
||||||
naive_distributed: bool,
|
|
||||||
):
|
|
||||||
"""Save model in HuggingFace format."""
|
|
||||||
if self.model is None:
|
|
||||||
raise RuntimeError("Model not initialized")
|
|
||||||
|
|
||||||
rank = dist.get_rank()
|
|
||||||
world_size = dist.get_world_size()
|
|
||||||
if rank == 0:
|
|
||||||
os.makedirs(path, exist_ok=True)
|
|
||||||
self.model_config.save_pretrained(path)
|
|
||||||
if tokenizer is not None:
|
|
||||||
tokenizer.save_pretrained(path)
|
|
||||||
|
|
||||||
if world_size > 1:
|
|
||||||
dist.barrier()
|
|
||||||
|
|
||||||
state_dict = self.model.state_dict()
|
|
||||||
|
|
||||||
if hasattr(self.model, "module"):
|
|
||||||
state_dict = {
|
|
||||||
k.replace("module.", "", 1) if k.startswith("module.") else k: v.cpu()
|
|
||||||
for k, v in state_dict.items()
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
state_dict = {k: v.cpu() for k, v in state_dict.items()}
|
|
||||||
|
|
||||||
if world_size > 1 and naive_distributed:
|
|
||||||
# Only supports saving each rank's model partition as a separate file
|
|
||||||
gathered_state_dicts = None
|
|
||||||
if rank == 0:
|
|
||||||
gathered_state_dicts = [None for _ in range(world_size)]
|
|
||||||
|
|
||||||
dist.gather_object(
|
|
||||||
obj=state_dict, object_gather_list=gathered_state_dicts, dst=0
|
|
||||||
)
|
|
||||||
|
|
||||||
if rank == 0:
|
|
||||||
for i, state_dict in enumerate(gathered_state_dicts):
|
|
||||||
save_file(state_dict, f"{path}/rank_{i:02d}_model.safetensors")
|
|
||||||
else:
|
|
||||||
self.model.save_pretrained(path, state_dict=state_dict)
|
|
||||||
|
|
||||||
if world_size > 1:
|
|
||||||
dist.barrier()
|
|
||||||
|
|
||||||
def _load_model_from_hf(self, path: str, naive_distributed: bool):
|
|
||||||
"""Load model from HuggingFace format."""
|
|
||||||
|
|
||||||
rank = dist.get_rank()
|
|
||||||
# Only supports loading full model parameters from a HuggingFace checkpoint
# on rank 0, or loading each rank's local model partition.
|
|
||||||
if rank == 0 or is_existing_local_path(path):
|
|
||||||
if naive_distributed:
|
|
||||||
path = f"{path}/rank_{rank:02d}_model.safetensors"
|
|
||||||
full_state = get_state_dict_from_repo_id_or_path(path)
|
|
||||||
|
|
||||||
if hasattr(self.model, "module") and not any(
k.startswith("module.") for k in full_state
):
|
|
||||||
full_state = {
|
|
||||||
f"module.{k}" if not k.startswith("module.") else k: v
|
|
||||||
for k, v in full_state.items()
|
|
||||||
}
|
|
||||||
self.model.load_state_dict(
|
|
||||||
full_state, strict=not self.model_config.tie_word_embeddings
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.model_config.tie_word_embeddings:
|
|
||||||
self.model.tie_weights()
|
|
||||||
|
|
||||||
def upload_weights(self, meta: WeightUpdateMeta):
|
|
||||||
if meta.type == "nccl":
|
|
||||||
if not self.weight_update_group_initialized:
|
|
||||||
self._init_distributed_weight_update(meta)
|
|
||||||
self._update_weights_from_distributed()
|
|
||||||
elif meta.type == "disk":
|
|
||||||
self._save_model_to_hf(meta.path, self.tokenizer, meta.naive_distributed)
|
|
||||||
update_name = names.update_weights_from_disk(
|
|
||||||
self.config.experiment_name,
|
|
||||||
self.config.trial_name,
|
|
||||||
meta.model_version,
|
|
||||||
)
|
|
||||||
name_resolve.add(update_name, str(time.time_ns()), keepalive_ttl=120)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown weight update type {meta.type}")
|
|
||||||
|
|
||||||
def _init_distributed_weight_update(self, meta: WeightUpdateMeta):
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Distributed weight update is not implemented for HFEngine yet. "
|
|
||||||
)
|
|
||||||
|
|
||||||
def _update_weights_from_distributed(self):
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Distributed weight update is not implemented for HFEngine yet. "
|
|
||||||
)
|
|
||||||
|
|
||||||
def step_lr_scheduler(self):
|
|
||||||
assert self.lr_scheduler is not None
|
|
||||||
return self.lr_scheduler.step()
|
|
||||||
|
|
||||||
def _prepare_mb_list(self, input_: TensorDict) -> MicroBatchList:
|
|
||||||
assert "attention_mask" in input_ and "input_ids" in input_
|
|
||||||
if isinstance(input_, dict):
|
|
||||||
input_ = TensorDict(input_, batch_size=[input_["input_ids"].shape[0]])
|
|
||||||
input_ = amend_position_ids(input_)
|
|
||||||
packed_input = pack_tensor_dict(input_)
|
|
||||||
mb_list = split_packed_tensor_dict_into_mb_list(
|
|
||||||
packed_input,
|
|
||||||
self.config.mb_spec,
|
|
||||||
)
|
|
||||||
mb_list = pad_mb_list(mb_list, pad_value=0.0)
|
|
||||||
# NOTE: We unsqueeze here because HuggingFace transformer models require
# packed input to be of shape [1, total_seqlen].
|
|
||||||
mb_list = unsqueeze_mb_list(mb_list)
|
|
||||||
return mb_list
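
As the note says, the padded batch is packed into a single row before being fed to the model. A toy sketch of that packing step, assuming two right-padded sequences (this is not the engine's actual utility code):

import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
input_ids = torch.tensor([[5, 6, 7, 0], [8, 9, 0, 0]])
lens = attention_mask.sum(-1)                                  # tensor([3, 2])
packed = torch.cat([row[:l] for row, l in zip(input_ids, lens)])
packed = packed.unsqueeze(0)                                   # shape [1, 5] == [1, total_seqlen]
cu_seqlens = F.pad(torch.cumsum(lens, 0), (1, 0))              # tensor([0, 3, 5])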
|
|
||||||
|
|
||||||
def train_batch(
|
|
||||||
self,
|
|
||||||
input_: TensorDict,
|
|
||||||
loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
|
|
||||||
loss_weight_fn: Callable[[Dict], float],
|
|
||||||
) -> Dict[str, float]:
|
|
||||||
"""Train on a batch using gradient accumulation."""
|
|
||||||
input_ = input_.to(self.device)
|
|
||||||
assert self.optimizer is not None
|
|
||||||
assert self.optimizer_config is not None
|
|
||||||
assert self.lr_scheduler is not None
|
|
||||||
|
|
||||||
self.optimizer.zero_grad()
|
|
||||||
mb_list = self._prepare_mb_list(input_)
|
|
||||||
|
|
||||||
total_loss_weight = torch.tensor(
|
|
||||||
sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32
|
|
||||||
)
|
|
||||||
assert total_loss_weight != 0
|
|
||||||
|
|
||||||
# Process microbatches with gradient accumulation
|
|
||||||
for pad_length, padded_mb_input, mb_input in zip(
|
|
||||||
mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
|
|
||||||
):
|
|
||||||
outputs = self.model(**padded_mb_input)
|
|
||||||
|
|
||||||
logits = outputs.logits.squeeze(0)
|
|
||||||
logits = logits[:-pad_length] if pad_length > 0 else logits
|
|
||||||
loss = loss_fn(logits, mb_input)
|
|
||||||
loss_scale = loss_weight_fn(mb_input) / total_loss_weight
|
|
||||||
|
|
||||||
loss *= loss_scale
|
|
||||||
loss.backward()
|
|
||||||
|
|
||||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
|
||||||
self.model.parameters(),
|
|
||||||
self.optimizer_config.gradient_clipping,
|
|
||||||
norm_type=2.0,
|
|
||||||
error_if_nonfinite=False,
|
|
||||||
foreach=None,
|
|
||||||
)
|
|
||||||
if not torch.isfinite(grad_norm):
|
|
||||||
self.optimizer.zero_grad()
|
|
||||||
update_successful = False
|
|
||||||
else:
|
|
||||||
self.optimizer.step()
|
|
||||||
update_successful = True
|
|
||||||
|
|
||||||
current_lr = self.lr_scheduler.get_last_lr()[0]
|
|
||||||
return dict(
|
|
||||||
update_successful=float(update_successful),
|
|
||||||
grad_norm=float(grad_norm) if grad_norm is not None else float("nan"),
|
|
||||||
lr=current_lr,
|
|
||||||
)
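
Each micro-batch loss is rescaled by loss_weight_fn(mb) / total_loss_weight, so the accumulated gradient matches what a single full-batch backward pass would produce. A hedged sketch of a token-count weight function (the key name loss_mask is an assumption about the batch layout):

def token_count_loss_weight(mb) -> float:
    # Weight each micro-batch by its number of unmasked tokens.
    return float(mb["loss_mask"].sum())

# stats = engine.train_batch(
#     batch, loss_fn=my_loss_fn, loss_weight_fn=token_count_loss_weight
# )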
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def eval_batch(
|
|
||||||
self,
|
|
||||||
input_: TensorDict,
|
|
||||||
loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
|
|
||||||
loss_weight_fn: Callable[[Dict], float],
|
|
||||||
) -> torch.Tensor | None:
|
|
||||||
"""Evaluate on a batch."""
|
|
||||||
mb_list = self._prepare_mb_list(input_)
|
|
||||||
total_loss_weight = torch.tensor(
|
|
||||||
sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32
|
|
||||||
)
|
|
||||||
assert total_loss_weight != 0
|
|
||||||
|
|
||||||
total_loss = 0.0
|
|
||||||
total_weight = 0.0
|
|
||||||
|
|
||||||
for pad_length, padded_mb_input, mb_input in zip(
|
|
||||||
mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
|
|
||||||
):
|
|
||||||
outputs = self.model(**padded_mb_input)
|
|
||||||
logits = outputs.logits.squeeze(0)
|
|
||||||
logits = logits[:-pad_length] if pad_length > 0 else logits
|
|
||||||
loss = loss_fn(logits, mb_input)
|
|
||||||
|
|
||||||
# Scale each micro-batch loss by its share of the total loss weight
|
|
||||||
loss_scale = loss_weight_fn(mb_input) / total_loss_weight
|
|
||||||
total_loss += loss.item() * loss_scale
|
|
||||||
total_weight += loss_scale
|
|
||||||
|
|
||||||
return torch.tensor(total_loss / total_weight)
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
input_: TensorDict,
|
|
||||||
output_seqlens: List[int] | None = None,
|
|
||||||
post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
|
|
||||||
aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
|
|
||||||
) -> Any | None:
|
|
||||||
"""Forward pass with optional post-processing."""
|
|
||||||
cu_seqlens = pack_tensor_dict(input_)["cu_seqlens"]
|
|
||||||
mb_list = self._prepare_mb_list(input_)
|
|
||||||
|
|
||||||
if output_seqlens is None:
|
|
||||||
output_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist()
|
|
||||||
|
|
||||||
results = []
|
|
||||||
for pad_length, padded_mb_input, mb_input in zip(
|
|
||||||
mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
|
|
||||||
):
|
|
||||||
outputs = self.model(**padded_mb_input)
|
|
||||||
logits = outputs.logits.squeeze(0)
|
|
||||||
logits = logits[:-pad_length] if pad_length > 0 else logits
|
|
||||||
|
|
||||||
if post_hook:
|
|
||||||
result = post_hook(logits, mb_input)
|
|
||||||
results.append(result)
|
|
||||||
else:
|
|
||||||
results.append(logits)
|
|
||||||
|
|
||||||
res = aggregate_fn(results)
|
|
||||||
output_seqlens = [output_seqlens[i] for i in mb_list.forward_indices]
|
|
||||||
unpacked = unpack_sequence(res, lens=output_seqlens, dim=0)
|
|
||||||
reordered = reorder_list(unpacked, mb_list.backward_indices)
|
|
||||||
return pad_and_stack_tensors_along_first_dim(reordered)
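
forward() processes sequences in micro-batch order and then restores the caller's ordering through backward_indices. A toy illustration of that reorder step, assuming backward_indices is the inverse permutation of forward_indices and reorder_list indexes by position:

seqs = ["a", "b", "c", "d"]
forward_indices = [2, 0, 3, 1]                         # processing order
processed = [seqs[i] for i in forward_indices]         # ['c', 'a', 'd', 'b']
backward_indices = [forward_indices.index(i) for i in range(len(seqs))]
restored = [processed[i] for i in backward_indices]    # ['a', 'b', 'c', 'd']
assert restored == seqs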
|
|
|
@ -114,7 +114,9 @@ class PPOActor:
|
||||||
values = torch.zeros_like(rewards)
|
values = torch.zeros_like(rewards)
|
||||||
else:
|
else:
|
||||||
values = data["values"]
|
values = data["values"]
|
||||||
advantages_reversed = []
|
advantages_reversed = [
|
||||||
|
torch.zeros(bs, dtype=torch.float32, device=values.device)
|
||||||
|
]
|
||||||
lastgaelam = 0
|
lastgaelam = 0
|
||||||
for t in reversed(range(max_seqlen - 1)):
|
for t in reversed(range(max_seqlen - 1)):
|
||||||
nextvalues = values[:, t + 1]
|
nextvalues = values[:, t + 1]
|
||||||
|
@ -123,9 +125,6 @@ class PPOActor:
|
||||||
delta = rewards[:, t] + self.discount * nextvalues - values[:, t]
|
delta = rewards[:, t] + self.discount * nextvalues - values[:, t]
|
||||||
lastgaelam = delta + self.discount * self.gae_lambda * lastgaelam
|
lastgaelam = delta + self.discount * self.gae_lambda * lastgaelam
|
||||||
advantages_reversed.append(lastgaelam)
|
advantages_reversed.append(lastgaelam)
|
||||||
advantages_reversed.append(
|
|
||||||
torch.zeros(bs, dtype=torch.float32, device=values.device)
|
|
||||||
)
|
|
||||||
advantages = torch.stack(advantages_reversed[::-1], dim=1)
|
advantages = torch.stack(advantages_reversed[::-1], dim=1)
|
||||||
|
|
||||||
# Optionally perform advantage normalization.
|
# Optionally perform advantage normalization.
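
The change above seeds advantages_reversed with a zero tensor, so after the [::-1] reversal the final timestep receives a zero advantage instead of the first. A minimal GAE sketch with toy numbers (discount, lambda, rewards, and values are made up):

import torch

gamma, lam = 0.99, 0.95
rewards = torch.tensor([[0.0, 0.0, 1.0]])
values = torch.tensor([[0.5, 0.4, 0.2]])
advantages_reversed = [torch.zeros(1)]                 # zero advantage for the last step
lastgaelam = torch.zeros(1)
for t in reversed(range(rewards.shape[1] - 1)):
    nextvalues = values[:, t + 1]
    delta = rewards[:, t] + gamma * nextvalues - values[:, t]
    lastgaelam = delta + gamma * lam * lastgaelam
    advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], dim=1)   # shape [1, 3]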
|
||||||
|
|
|
@ -1,5 +1,3 @@
|
||||||
from typing import Dict
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.data
|
import torch.utils.data
|
||||||
from tensordict import TensorDict
|
from tensordict import TensorDict
|
||||||
|
@ -44,9 +42,7 @@ class FSDPLMEngine(FSDPEngine):
|
||||||
return self.lm_engine.evaluate_lm(data)
|
return self.lm_engine.evaluate_lm(data)
|
||||||
|
|
||||||
|
|
||||||
def compute_packed_sft_loss(
|
def compute_packed_sft_loss(logits: torch.Tensor, input_: TensorDict) -> torch.Tensor:
|
||||||
logits: torch.Tensor, input_: Dict[str, torch.Tensor]
|
|
||||||
) -> torch.Tensor:
|
|
||||||
packed_input_ids: torch.Tensor = input_["input_ids"]
|
packed_input_ids: torch.Tensor = input_["input_ids"]
|
||||||
cu_seqlens: torch.Tensor = input_["cu_seqlens"]
|
cu_seqlens: torch.Tensor = input_["cu_seqlens"]
|
||||||
loss_mask = input_["loss_mask"].bool()
|
loss_mask = input_["loss_mask"].bool()
|
||||||
|
|
|
@ -7,10 +7,12 @@ import traceback
|
||||||
from concurrent.futures import ProcessPoolExecutor
|
from concurrent.futures import ProcessPoolExecutor
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from queue import Empty, Full, Queue
|
from queue import Empty, Full, Queue
|
||||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List
|
from typing import TYPE_CHECKING, Any, Callable, Dict, List
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
import requests
|
import requests
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
import uvloop
|
||||||
from tensordict import TensorDict
|
from tensordict import TensorDict
|
||||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||||
|
|
||||||
|
@ -23,8 +25,8 @@ from arealite.api.io_struct import (
|
||||||
RolloutStat,
|
RolloutStat,
|
||||||
WeightUpdateMeta,
|
WeightUpdateMeta,
|
||||||
)
|
)
|
||||||
from arealite.utils.http import arequest_with_retry
|
from arealite.utils.data import concat_padded_tensors
|
||||||
from arealite.utils.padding import concat_padded_tensors
|
from arealite.utils.http import arequest_with_retry, get_default_connector
|
||||||
from realhf.base import logging, name_resolve, names, pkg_version
|
from realhf.base import logging, name_resolve, names, pkg_version
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -37,7 +39,7 @@ if pkg_version.is_available("sglang"):
|
||||||
else:
|
else:
|
||||||
SGLANG_TOKEN_OUTPUT_IDENTIFIER = "token_ids"
|
SGLANG_TOKEN_OUTPUT_IDENTIFIER = "token_ids"
|
||||||
|
|
||||||
ROLLOUT_POLL_WAIT_TIME = 0.1
|
ROLLOUT_POLL_WAIT_TIME = 0.05
|
||||||
RID_CACHE_SIZE = 128
|
RID_CACHE_SIZE = 128
|
||||||
|
|
||||||
|
|
||||||
|
@ -98,6 +100,7 @@ class RemoteSGLangEngine(InferenceEngine):
|
||||||
|
|
||||||
def initialize(self, addr: str | None, ft_spec: FinetuneSpec = None):
|
def initialize(self, addr: str | None, ft_spec: FinetuneSpec = None):
|
||||||
self.rollout_tasks: Dict[str, asyncio.Task] = {}
|
self.rollout_tasks: Dict[str, asyncio.Task] = {}
|
||||||
|
|
||||||
self.executor = ProcessPoolExecutor(max_workers=1)
|
self.executor = ProcessPoolExecutor(max_workers=1)
|
||||||
self.rollout_thread = threading.Thread(target=self._rollout_thread)
|
self.rollout_thread = threading.Thread(target=self._rollout_thread)
|
||||||
self.rollout_thread.start()
|
self.rollout_thread.start()
|
||||||
|
@ -118,32 +121,39 @@ class RemoteSGLangEngine(InferenceEngine):
|
||||||
def _rollout_thread(self):
|
def _rollout_thread(self):
|
||||||
"""Thread that runs the rollout loop."""
|
"""Thread that runs the rollout loop."""
|
||||||
try:
|
try:
|
||||||
asyncio.run(self._rollout_thread_async())
|
uvloop.run(self._rollout_thread_async())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
async def _rollout_thread_async(self):
|
async def _rollout_thread_async(self):
|
||||||
pending_data = []
|
|
||||||
rollout_tasks = self.rollout_tasks
|
rollout_tasks = self.rollout_tasks
|
||||||
rid = 0
|
rid = 0
|
||||||
|
|
||||||
|
# NOTE: session is not thread-safe, but we only submit requests in the sub-thread.
|
||||||
|
self.session = aiohttp.ClientSession(
|
||||||
|
timeout=aiohttp.ClientTimeout(
|
||||||
|
total=self.config.request_timeout,
|
||||||
|
sock_connect=self.config.request_timeout,
|
||||||
|
connect=self.config.request_timeout,
|
||||||
|
),
|
||||||
|
read_bufsize=1024 * 1024 * 10,
|
||||||
|
connector=get_default_connector(),
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while not self.exiting.is_set():
|
while not self.exiting.is_set():
|
||||||
# Load next data from controller
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
data, workflow = self.input_queue.get_nowait()
|
|
||||||
logger.debug(f"Get data from puller: {data}")
|
|
||||||
pending_data.append(data)
|
|
||||||
except Empty:
|
|
||||||
logger.debug(f"No data from puller stream.")
|
|
||||||
break
|
|
||||||
|
|
||||||
# Check capacity
|
# Check capacity
|
||||||
capacity = self.get_capacity()
|
capacity = self.get_capacity()
|
||||||
# Create new rollout task
|
# Create new rollout task
|
||||||
while capacity > 0 and pending_data and not self.paused.is_set():
|
while (
|
||||||
|
capacity > 0
|
||||||
|
and not self.paused.is_set()
|
||||||
|
and self.input_queue.qsize() > 0
|
||||||
|
):
|
||||||
|
data, workflow = self.input_queue.get_nowait()
|
||||||
|
logger.debug(f"Get data from puller: {data}")
|
||||||
task = asyncio.create_task(
|
task = asyncio.create_task(
|
||||||
workflow.arun_episode(self, pending_data.pop(0)), name=str(rid)
|
workflow.arun_episode(self, data), name=str(rid)
|
||||||
)
|
)
|
||||||
with self.lock:
|
with self.lock:
|
||||||
rollout_tasks[str(rid)] = task
|
rollout_tasks[str(rid)] = task
|
||||||
|
@ -158,7 +168,6 @@ class RemoteSGLangEngine(InferenceEngine):
|
||||||
)
|
)
|
||||||
capacity -= 1
|
capacity -= 1
|
||||||
rid += 1
|
rid += 1
|
||||||
|
|
||||||
# Wait for rollout completion
|
# Wait for rollout completion
|
||||||
with self.lock:
|
with self.lock:
|
||||||
tasks = list(rollout_tasks.values())
|
tasks = list(rollout_tasks.values())
|
||||||
|
@ -169,11 +178,6 @@ class RemoteSGLangEngine(InferenceEngine):
|
||||||
timeout=ROLLOUT_POLL_WAIT_TIME,
|
timeout=ROLLOUT_POLL_WAIT_TIME,
|
||||||
return_when=asyncio.FIRST_COMPLETED,
|
return_when=asyncio.FIRST_COMPLETED,
|
||||||
)
|
)
|
||||||
if not done:
|
|
||||||
await asyncio.sleep(1)
|
|
||||||
else:
|
|
||||||
await asyncio.sleep(1)
|
|
||||||
|
|
||||||
# Collect done results
|
# Collect done results
|
||||||
for task in done:
|
for task in done:
|
||||||
traj = await task
|
traj = await task
|
||||||
|
@ -199,6 +203,7 @@ class RemoteSGLangEngine(InferenceEngine):
|
||||||
f"running: {self.rollout_stat.running}, "
|
f"running: {self.rollout_stat.running}, "
|
||||||
f"accepted: {self.rollout_stat.accepted}."
|
f"accepted: {self.rollout_stat.accepted}."
|
||||||
)
|
)
|
||||||
|
await asyncio.sleep(1)
|
||||||
except Exception:
|
except Exception:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
finally:
|
finally:
|
||||||
|
@ -213,10 +218,11 @@ class RemoteSGLangEngine(InferenceEngine):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def choose_server(self) -> str:
|
def choose_server(self) -> str:
|
||||||
if self.config.schedule_policy == "round_robin":
|
with self.lock:
|
||||||
server = self.addresses[self.server_idx]
|
if self.config.schedule_policy == "round_robin":
|
||||||
self.server_idx = (self.server_idx + 1) % len(self.addresses)
|
server = self.addresses[self.server_idx]
|
||||||
return server
|
self.server_idx = (self.server_idx + 1) % len(self.addresses)
|
||||||
|
return server
|
||||||
raise NotImplementedError("Only round-robin scheduling is implemented.")
|
raise NotImplementedError("Only round-robin scheduling is implemented.")
|
||||||
|
|
||||||
async def agenerate(self, req: LLMRequest) -> LLMResponse:
|
async def agenerate(self, req: LLMRequest) -> LLMResponse:
|
||||||
|
@ -253,7 +259,6 @@ class RemoteSGLangEngine(InferenceEngine):
|
||||||
accumulated_versions = []
|
accumulated_versions = []
|
||||||
|
|
||||||
# Deal with rollout interruption
|
# Deal with rollout interruption
|
||||||
completions = ""
|
|
||||||
stop_reason = "length"
|
stop_reason = "length"
|
||||||
|
|
||||||
if req.rid in self.rid_to_address:
|
if req.rid in self.rid_to_address:
|
||||||
|
@ -273,11 +278,12 @@ class RemoteSGLangEngine(InferenceEngine):
|
||||||
):
|
):
|
||||||
# loop until the generation is complete
|
# loop until the generation is complete
|
||||||
result = await arequest_with_retry(
|
result = await arequest_with_retry(
|
||||||
addr=self.choose_server(),
|
session=self.session,
|
||||||
|
addr=server_addr,
|
||||||
endpoint="/generate",
|
endpoint="/generate",
|
||||||
payload=payload,
|
payload=payload,
|
||||||
method="POST",
|
method="POST",
|
||||||
max_retries=3,
|
max_retries=self.config.request_retries,
|
||||||
timeout=self.config.request_timeout,
|
timeout=self.config.request_timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -297,10 +303,7 @@ class RemoteSGLangEngine(InferenceEngine):
|
||||||
stop_reason = finish_reason["type"]
|
stop_reason = finish_reason["type"]
|
||||||
|
|
||||||
payload["input_ids"] += result[SGLANG_TOKEN_OUTPUT_IDENTIFIER]
|
payload["input_ids"] += result[SGLANG_TOKEN_OUTPUT_IDENTIFIER]
|
||||||
sample_params["max_new_tokens"] = min(
|
sample_params["max_new_tokens"] -= len(output_tokens)
|
||||||
sample_params["max_new_tokens"],
|
|
||||||
gconfig.max_new_tokens - len(output_tokens),
|
|
||||||
)
|
|
||||||
|
|
||||||
latency = time.perf_counter() - start_time
|
latency = time.perf_counter() - start_time
|
||||||
|
|
||||||
|
@ -408,18 +411,24 @@ class RemoteSGLangEngine(InferenceEngine):
|
||||||
|
|
||||||
def prepare_batch(
|
def prepare_batch(
|
||||||
self,
|
self,
|
||||||
data_generator: Iterator,
|
|
||||||
dataloader: StatefulDataLoader,
|
dataloader: StatefulDataLoader,
|
||||||
workflow: "RolloutWorkflow",
|
workflow: "RolloutWorkflow",
|
||||||
):
|
):
|
||||||
|
if not hasattr(self, "data_generator"):
|
||||||
|
self.data_generator = iter(dataloader)
|
||||||
assert dataloader.batch_size is not None
|
assert dataloader.batch_size is not None
|
||||||
while True:
|
while True:
|
||||||
if self.get_capacity() + dataloader.batch_size > 0:
|
# Submit at least two batches to allow maximum overlap
|
||||||
|
if (
|
||||||
|
self.get_capacity() + dataloader.batch_size > 0
|
||||||
|
and self.input_queue.qsize() + dataloader.batch_size
|
||||||
|
< self.input_queue.maxsize
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
data = next(data_generator)
|
data = next(self.data_generator)
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
data_generator = iter(dataloader)
|
self.data_generator = iter(dataloader)
|
||||||
data = next(data_generator)
|
data = next(self.data_generator)
|
||||||
for item in data:
|
for item in data:
|
||||||
self.submit(item, workflow=workflow)
|
self.submit(item, workflow=workflow)
|
||||||
try:
|
try:
|
||||||
|
@ -435,10 +444,12 @@ class RemoteSGLangEngine(InferenceEngine):
|
||||||
|
|
||||||
|
|
||||||
async def aupdate_weights_from_disk(
|
async def aupdate_weights_from_disk(
|
||||||
addr, path: str, request_retries: int, request_timeout: float
|
session, addr, path: str, request_retries: int, request_timeout: float
|
||||||
):
|
):
|
||||||
|
tik = time.time()
|
||||||
res = await arequest_with_retry(
|
res = await arequest_with_retry(
|
||||||
addr=addr,
|
addr=addr,
|
||||||
|
session=session,
|
||||||
endpoint="/update_weights_from_disk",
|
endpoint="/update_weights_from_disk",
|
||||||
payload=dict(model_path=str(path), allow_interrupt=True),
|
payload=dict(model_path=str(path), allow_interrupt=True),
|
||||||
method="POST",
|
method="POST",
|
||||||
|
@ -472,9 +483,19 @@ def update_weights_from_disk(
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Begin update weights from {path}, responded in {(load_timestamp - save_timestamp):.2f}s"
|
f"Begin update weights from {path}, responded in {(load_timestamp - save_timestamp):.2f}s"
|
||||||
)
|
)
|
||||||
|
session = aiohttp.ClientSession(
|
||||||
|
timeout=aiohttp.ClientTimeout(
|
||||||
|
total=request_timeout,
|
||||||
|
sock_connect=request_timeout,
|
||||||
|
connect=request_timeout,
|
||||||
|
),
|
||||||
|
read_bufsize=1024 * 1024 * 10,
|
||||||
|
connector=get_default_connector(),
|
||||||
|
)
|
||||||
jobs = [
|
jobs = [
|
||||||
aupdate_weights_from_disk(
|
aupdate_weights_from_disk(
|
||||||
addr,
|
session=session,
|
||||||
|
addr=addr,
|
||||||
path=path,
|
path=path,
|
||||||
request_retries=request_retries,
|
request_retries=request_retries,
|
||||||
request_timeout=request_timeout,
|
request_timeout=request_timeout,
|
||||||
|
@ -482,8 +503,9 @@ def update_weights_from_disk(
|
||||||
for addr in addresses
|
for addr in addresses
|
||||||
]
|
]
|
||||||
await asyncio.gather(*jobs)
|
await asyncio.gather(*jobs)
|
||||||
|
await session.close()
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Loading weights done in {(datetime.now().timestamp() - load_timestamp):.2f}s"
|
f"Loading weights done in {(datetime.now().timestamp() - load_timestamp):.2f}s"
|
||||||
)
|
)
|
||||||
|
|
||||||
return asyncio.run(_fn())
|
return uvloop.run(_fn())
|
||||||
|
|
|
@ -5,6 +5,7 @@ import time
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import requests
|
||||||
import torch
|
import torch
|
||||||
from tensordict import TensorDict
|
from tensordict import TensorDict
|
||||||
|
|
||||||
|
@ -28,6 +29,17 @@ HOST = network.gethostip()
|
||||||
RUN_SERVER_TIMEOUT = 180
|
RUN_SERVER_TIMEOUT = 180
|
||||||
|
|
||||||
|
|
||||||
|
def check_server_health(base_url):
|
||||||
|
try:
|
||||||
|
response = requests.get(
|
||||||
|
f"{base_url}/metrics",
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
return response.status_code == 200
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def sglang_server():
|
def sglang_server():
|
||||||
from realhf.base import seeding
|
from realhf.base import seeding
|
||||||
|
|
|
@ -67,7 +67,6 @@ async def test_local_sglang_generate():
|
||||||
gconfig=GenerationHyperparameters(max_new_tokens=16),
|
gconfig=GenerationHyperparameters(max_new_tokens=16),
|
||||||
)
|
)
|
||||||
resp = await engine.agenerate(req)
|
resp = await engine.agenerate(req)
|
||||||
print(resp.completions)
|
|
||||||
|
|
||||||
assert isinstance(resp, LLMResponse)
|
assert isinstance(resp, LLMResponse)
|
||||||
assert resp.input_tokens == req.input_ids
|
assert resp.input_tokens == req.input_ids
|
||||||
|
@ -76,9 +75,6 @@ async def test_local_sglang_generate():
|
||||||
== len(resp.output_tokens)
|
== len(resp.output_tokens)
|
||||||
== len(resp.output_versions)
|
== len(resp.output_versions)
|
||||||
)
|
)
|
||||||
assert isinstance(resp.completions, str)
|
|
||||||
|
|
||||||
time.sleep(5)
|
|
||||||
engine.destroy()
|
engine.destroy()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -13,7 +13,6 @@ from transformers import AutoTokenizer
|
||||||
|
|
||||||
from arealite.api.cli_args import MicroBatchSpec, OptimizerConfig, TrainEngineConfig
|
from arealite.api.cli_args import MicroBatchSpec, OptimizerConfig, TrainEngineConfig
|
||||||
from arealite.api.io_struct import FinetuneSpec, SaveLoadMeta
|
from arealite.api.io_struct import FinetuneSpec, SaveLoadMeta
|
||||||
from arealite.engine.fsdp_engine import FSDPEngine
|
|
||||||
|
|
||||||
VOCAB_SIZE = 100
|
VOCAB_SIZE = 100
|
||||||
MODEL_PATH = "/storage/testing/models/Qwen__Qwen3-1.7B/"
|
MODEL_PATH = "/storage/testing/models/Qwen__Qwen3-1.7B/"
|
||||||
|
@ -53,10 +52,10 @@ def mock_input(
|
||||||
|
|
||||||
|
|
||||||
def get_engine(engine_type: str, model_path: str):
|
def get_engine(engine_type: str, model_path: str):
|
||||||
|
from arealite.engine.autotp_engine import DeepSpeedAutoTPEngine
|
||||||
from arealite.engine.fsdp_engine import FSDPEngine
|
from arealite.engine.fsdp_engine import FSDPEngine
|
||||||
from arealite.engine.hf_engine import HFEngine
|
|
||||||
|
|
||||||
engine_cls = {"hf": HFEngine, "fsdp": FSDPEngine}[engine_type]
|
engine_cls = {"auto_tp": DeepSpeedAutoTPEngine, "fsdp": FSDPEngine}[engine_type]
|
||||||
|
|
||||||
engine_config = TrainEngineConfig(
|
engine_config = TrainEngineConfig(
|
||||||
experiment_name=f"test-{engine_type}-engine",
|
experiment_name=f"test-{engine_type}-engine",
|
||||||
|
@ -75,7 +74,7 @@ def mock_loss_fn(logits: torch.Tensor, input_data: Dict) -> torch.Tensor:
|
||||||
return torch.mean(logits)
|
return torch.mean(logits)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module", params=["fsdp", "hf"])
|
@pytest.fixture(scope="module", params=["fsdp", "auto_tp"])
|
||||||
def engine(request):
|
def engine(request):
|
||||||
os.environ.update(
|
os.environ.update(
|
||||||
{
|
{
|
||||||
|
@ -136,6 +135,12 @@ def test_train_batch(engine, mock_input):
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def test_hf_save_load_weights(tmp_path_factory, engine, mock_input):
|
def test_hf_save_load_weights(tmp_path_factory, engine, mock_input):
|
||||||
|
from arealite.engine.autotp_engine import DeepSpeedAutoTPEngine
|
||||||
|
|
||||||
|
if isinstance(engine, DeepSpeedAutoTPEngine):
|
||||||
|
print("AutoTP engine does not support HF save/load for now.")
|
||||||
|
return
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
|
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
|
||||||
path = tmp_path_factory.mktemp("hf_engine_test")
|
path = tmp_path_factory.mktemp("hf_engine_test")
|
||||||
save_load_meta = SaveLoadMeta(
|
save_load_meta = SaveLoadMeta(
|
|
@ -92,9 +92,7 @@ def unpad_input(
|
||||||
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
||||||
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
||||||
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
||||||
cu_seqlens = F.pad(
|
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
||||||
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
|
|
||||||
)
|
|
||||||
return (
|
return (
|
||||||
rearrange(hidden_states, "b s ... -> (b s) ...")[indices],
|
rearrange(hidden_states, "b s ... -> (b s) ...")[indices],
|
||||||
indices,
|
indices,
|
||||||
|
@ -116,16 +114,12 @@ def concat_padded_tensors(
|
||||||
if not tensor_dicts:
|
if not tensor_dicts:
|
||||||
return TensorDict()
|
return TensorDict()
|
||||||
|
|
||||||
# Find max sequence length across all dictionaries
|
batch_sizes = [tuple(d.batch_size) for d in tensor_dicts]
|
||||||
lens = []
|
new_batch_size = [sum(x[0] for x in batch_sizes), *batch_sizes[0][1:]]
|
||||||
for tensor_dict in tensor_dicts:
|
|
||||||
for key, tensor in tensor_dict.items():
|
|
||||||
if key != "attention_mask" and len(tensor.shape) == 2:
|
|
||||||
lens.append(tensor.shape[1])
|
|
||||||
break
|
|
||||||
max_length = max(lens)
|
|
||||||
attn_mask = torch.arange(max_length).unsqueeze(0) < torch.tensor(lens).unsqueeze(1)
|
|
||||||
|
|
||||||
|
# Find max sequence length across all dictionaries
|
||||||
|
assert all("attention_mask" in td for td in tensor_dicts)
|
||||||
|
max_length = max([x["attention_mask"].shape[1] for x in tensor_dicts])
|
||||||
result = {}
|
result = {}
|
||||||
# Process each key
|
# Process each key
|
||||||
for key in tensor_dicts[0].keys():
|
for key in tensor_dicts[0].keys():
|
||||||
|
@ -154,9 +148,7 @@ def concat_padded_tensors(
|
||||||
tensors_to_concat.append(tensor)
|
tensors_to_concat.append(tensor)
|
||||||
|
|
||||||
result[key] = torch.cat(tensors_to_concat, dim=0)
|
result[key] = torch.cat(tensors_to_concat, dim=0)
|
||||||
if "attention_mask" not in result:
|
return TensorDict(result, batch_size=new_batch_size)
|
||||||
result["attention_mask"] = attn_mask
|
|
||||||
return TensorDict(result, batch_size=[len(lens)])
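
The rewritten helper takes the target length from each dict's attention_mask and concatenates along the batch dimension. A hedged usage sketch (the padding value applied to the shorter dict is an implementation detail not shown here):

import torch
from tensordict import TensorDict

a = TensorDict(
    {
        "input_ids": torch.ones(2, 3, dtype=torch.long),
        "attention_mask": torch.ones(2, 3, dtype=torch.bool),
    },
    batch_size=[2],
)
b = TensorDict(
    {
        "input_ids": torch.ones(1, 4, dtype=torch.long),
        "attention_mask": torch.ones(1, 4, dtype=torch.bool),
    },
    batch_size=[1],
)
# out = concat_padded_tensors([a, b])   # expected batch_size [3], sequences padded to length 4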
|
|
||||||
|
|
||||||
|
|
||||||
def to_device(data: Dict[str, torch.Tensor | Any], device) -> Dict[str, torch.Tensor]:
|
def to_device(data: Dict[str, torch.Tensor | Any], device) -> Dict[str, torch.Tensor]:
|
||||||
|
@ -231,7 +223,7 @@ def pack_tensor_dict(data: TensorDict):
|
||||||
cu_seqlens = torch.cumsum(lens, dim=0)
|
cu_seqlens = torch.cumsum(lens, dim=0)
|
||||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
||||||
|
|
||||||
total_length = cu_seqlens[-1].item()
|
total_length = int(cu_seqlens[-1].item())
|
||||||
# Pack tensors
|
# Pack tensors
|
||||||
packed_data = {}
|
packed_data = {}
|
||||||
for key, value in data.items():
|
for key, value in data.items():
|
||||||
|
@ -246,7 +238,7 @@ def pack_tensor_dict(data: TensorDict):
|
||||||
and value.shape[1] == seq_len
|
and value.shape[1] == seq_len
|
||||||
):
|
):
|
||||||
packed_tensor = torch.empty(
|
packed_tensor = torch.empty(
|
||||||
total_length, *value.shape[2:], dtype=value.dtype, device=value.device
|
(total_length, *value.shape[2:]), dtype=value.dtype, device=value.device
|
||||||
)
|
)
|
||||||
# Fill the packed tensor with values from the original tensor
|
# Fill the packed tensor with values from the original tensor
|
||||||
for i in range(bs):
|
for i in range(bs):
|
||||||
|
@ -363,7 +355,7 @@ def split_padded_tensor_dict_into_mb_list(
|
||||||
mbs=results,
|
mbs=results,
|
||||||
mb_spec=mb_spec,
|
mb_spec=mb_spec,
|
||||||
forward_indices=forward_indices,
|
forward_indices=forward_indices,
|
||||||
backward_indices=backward_indices,
|
backward_indices=backward_indices.tolist(),
|
||||||
group_lens=group_lens,
|
group_lens=group_lens,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -121,19 +121,24 @@ def masked_normalization(
|
||||||
|
|
||||||
def ppo_actor_loss_fn(
|
def ppo_actor_loss_fn(
|
||||||
logprobs: torch.Tensor,
|
logprobs: torch.Tensor,
|
||||||
|
proximal_logprobs: torch.Tensor,
|
||||||
old_logprobs: torch.Tensor,
|
old_logprobs: torch.Tensor,
|
||||||
advantages: torch.Tensor,
|
advantages: torch.Tensor,
|
||||||
eps_clip: float,
|
eps_clip: float,
|
||||||
loss_mask: torch.Tensor,
|
loss_mask: torch.Tensor,
|
||||||
c_clip: Optional[float] = None,
|
c_clip: Optional[float] = None,
|
||||||
proximal_logprobs: Optional[torch.Tensor] = None,
|
|
||||||
behav_imp_weight_cap: Optional[float] = None,
|
behav_imp_weight_cap: Optional[float] = None,
|
||||||
) -> Tuple[torch.Tensor, Dict]:
|
) -> Tuple[torch.Tensor, Dict]:
|
||||||
denorm_logprobs = (
|
"""
|
||||||
proximal_logprobs if proximal_logprobs is not None else old_logprobs
|
When decoupled loss is disabled:
|
||||||
)
|
1. if log-probabilities are recomputed, both old_logprobs and proximal_logprobs hold the recomputed values;
2. otherwise, both old_logprobs and proximal_logprobs come from the inference backend.
|
||||||
|
|
||||||
|
When decoupled loss is enabled, proximal_logprobs is the recomputed logp,
|
||||||
|
while old_logprobs is produced by the inference engine.
|
||||||
|
"""
|
||||||
loss_mask_count = loss_mask.count_nonzero() or 1
|
loss_mask_count = loss_mask.count_nonzero() or 1
|
||||||
ratio = torch.where(loss_mask, torch.exp(logprobs - denorm_logprobs), 0)
|
ratio = torch.where(loss_mask, torch.exp(logprobs - proximal_logprobs), 0)
|
||||||
clipped_ratio = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip)
|
clipped_ratio = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip)
|
||||||
pg_loss1 = -advantages * ratio
|
pg_loss1 = -advantages * ratio
|
||||||
pg_loss2 = -advantages * clipped_ratio
|
pg_loss2 = -advantages * clipped_ratio
|
||||||
|
@ -146,17 +151,16 @@ def ppo_actor_loss_fn(
|
||||||
pg_loss = torch.min(pg_loss, pg_loss3)
|
pg_loss = torch.min(pg_loss, pg_loss3)
|
||||||
else:
|
else:
|
||||||
dual_clip_mask = torch.zeros_like(clip_mask)
|
dual_clip_mask = torch.zeros_like(clip_mask)
|
||||||
if proximal_logprobs is not None:
|
behav_kl = proximal_logprobs - old_logprobs
|
||||||
behav_kl = proximal_logprobs - old_logprobs
|
behav_imp_weight = behav_kl.exp()
|
||||||
behav_imp_weight = behav_kl.exp()
|
behav_mask = (
|
||||||
behav_mask = (
|
(behav_imp_weight <= behav_imp_weight_cap).logical_and(loss_mask)
|
||||||
(behav_imp_weight <= behav_imp_weight_cap).logical_and(loss_mask)
|
if behav_imp_weight_cap is not None
|
||||||
if behav_imp_weight_cap is not None
|
else loss_mask
|
||||||
else loss_mask
|
)
|
||||||
)
|
behav_kl = torch.where(behav_mask, behav_kl, 0.0)
|
||||||
behav_kl = torch.where(behav_mask, behav_kl, 0.0)
|
behav_imp_weight = torch.where(behav_mask, behav_imp_weight, 0.0)
|
||||||
behav_imp_weight = torch.where(behav_mask, behav_imp_weight, 0.0)
|
pg_loss = pg_loss * behav_imp_weight
|
||||||
pg_loss = pg_loss * behav_imp_weight
|
|
||||||
logging_loss = pg_loss.detach()
|
logging_loss = pg_loss.detach()
|
||||||
pg_loss = torch.where(loss_mask, pg_loss, 0).sum() / loss_mask_count
|
pg_loss = torch.where(loss_mask, pg_loss, 0).sum() / loss_mask_count
|
||||||
clip_mask.logical_and_(loss_mask)
|
clip_mask.logical_and_(loss_mask)
|
||||||
|
@ -164,7 +168,7 @@ def ppo_actor_loss_fn(
|
||||||
stat = dict(
|
stat = dict(
|
||||||
loss=logging_loss,
|
loss=logging_loss,
|
||||||
importance_weight=ratio.detach(),
|
importance_weight=ratio.detach(),
|
||||||
approx_kl=(logprobs - denorm_logprobs).detach(),
|
approx_kl=(logprobs - proximal_logprobs).detach(),
|
||||||
clip_mask=clip_mask,
|
clip_mask=clip_mask,
|
||||||
dual_clip_mask=dual_clip_mask,
|
dual_clip_mask=dual_clip_mask,
|
||||||
)
|
)
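
A minimal sketch of the decoupled-loss computation the docstring describes: the PPO ratio is taken against the recomputed proximal log-probs, while a capped behavior importance weight corrects for the stale inference-engine log-probs. Toy tensors; eps_clip and the cap are made-up values, and the masking over loss_mask is omitted:

import torch

logprobs = torch.tensor([-1.0, -1.2])             # current policy
proximal_logprobs = torch.tensor([-1.1, -1.1])    # recomputed ("proximal") policy
old_logprobs = torch.tensor([-1.3, -1.0])         # inference-engine (behavior) policy
advantages = torch.tensor([0.5, -0.2])
eps_clip, cap = 0.2, 5.0

ratio = torch.exp(logprobs - proximal_logprobs)
pg1 = -advantages * ratio
pg2 = -advantages * ratio.clamp(1 - eps_clip, 1 + eps_clip)
pg_loss = torch.max(pg1, pg2)

behav_imp_weight = torch.exp(proximal_logprobs - old_logprobs)
behav_imp_weight = torch.where(
    behav_imp_weight <= cap, behav_imp_weight, torch.zeros_like(behav_imp_weight)
)
pg_loss = (pg_loss * behav_imp_weight).mean()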
|
||||||
|
|
|
@ -3,45 +3,74 @@ from typing import Any, Dict, Optional
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
|
from realhf.base import logging
|
||||||
|
|
||||||
DEFAULT_RETRIES = 1
|
DEFAULT_RETRIES = 1
|
||||||
DEFAULT_REQUEST_TIMEOUT = 3600
|
DEFAULT_REQUEST_TIMEOUT = 3600
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__file__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_default_connector():
|
||||||
|
return aiohttp.TCPConnector(limit=0, use_dns_cache=False, force_close=True)
|
||||||
|
|
||||||
|
|
||||||
async def arequest_with_retry(
|
async def arequest_with_retry(
|
||||||
addr: str,
|
addr: str,
|
||||||
endpoint: str,
|
endpoint: str,
|
||||||
payload: Optional[Dict[str, Any]] = None,
|
payload: Optional[Dict[str, Any]] = None,
|
||||||
|
session: aiohttp.ClientSession | None = None,
|
||||||
method: str = "POST",
|
method: str = "POST",
|
||||||
max_retries: Optional[int] = None,
|
max_retries: Optional[int] = None,
|
||||||
timeout: Optional[float] = None,
|
timeout: Optional[float] = None,
|
||||||
retry_delay: float = 1.0,
|
retry_delay: float = 1.0,
|
||||||
) -> aiohttp.ClientResponse:
|
verbose=False,
|
||||||
|
) -> Dict:
|
||||||
timeout = timeout or DEFAULT_REQUEST_TIMEOUT
|
timeout = timeout or DEFAULT_REQUEST_TIMEOUT
|
||||||
last_exception = None
|
last_exception = None
|
||||||
max_retries = max_retries or DEFAULT_RETRIES
|
max_retries = max_retries or DEFAULT_RETRIES
|
||||||
base_url = f"http://{addr}"
|
base_url = f"http://{addr}"
|
||||||
url = f"{base_url}{endpoint}"
|
url = f"{base_url}{endpoint}"
|
||||||
|
|
||||||
|
timeo = aiohttp.ClientTimeout(
|
||||||
|
total=timeout,
|
||||||
|
sock_connect=timeout,
|
||||||
|
connect=timeout,
|
||||||
|
)
|
||||||
|
if session is None:
|
||||||
|
_session = aiohttp.ClientSession(
|
||||||
|
timeout=timeo,
|
||||||
|
read_bufsize=1024 * 1024 * 10,
|
||||||
|
connector=get_default_connector(),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
_session = session
|
||||||
|
|
||||||
for attempt in range(max_retries):
|
for attempt in range(max_retries):
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession(
|
if verbose:
|
||||||
timeout=aiohttp.ClientTimeout(
|
logger.info("enter client session, start sending requests")
|
||||||
total=timeout,
|
if method.upper() == "GET":
|
||||||
sock_connect=timeout,
|
ctx = _session.get(url, timeout=timeo)
|
||||||
)
|
elif method.upper() == "POST":
|
||||||
) as session:
|
ctx = _session.post(url, json=payload, timeout=timeo)
|
||||||
if method.upper() == "GET":
|
elif method.upper() == "PUT":
|
||||||
response = await session.get(url)
|
ctx = _session.put(url, json=payload, timeout=timeo)
|
||||||
elif method.upper() == "POST":
|
elif method.upper() == "DELETE":
|
||||||
response = await session.post(url, json=payload)
|
ctx = _session.delete(url, timeout=timeo)
|
||||||
elif method.upper() == "PUT":
|
else:
|
||||||
response = await session.put(url, json=payload)
|
raise ValueError(f"Unsupported HTTP method: {method}")
|
||||||
elif method.upper() == "DELETE":
|
async with ctx as response:
|
||||||
response = await session.delete(url)
|
if verbose:
|
||||||
else:
|
logger.info("http requests return")
|
||||||
raise ValueError(f"Unsupported HTTP method: {method}")
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
return await response.json()
|
res = await response.json()
|
||||||
|
if verbose:
|
||||||
|
logger.info("get http result")
|
||||||
|
if session is None:
|
||||||
|
await _session.close()
|
||||||
|
return res
|
||||||
except (
|
except (
|
||||||
aiohttp.ClientError,
|
aiohttp.ClientError,
|
||||||
aiohttp.ClientResponseError,
|
aiohttp.ClientResponseError,
|
||||||
|
@ -51,6 +80,10 @@ async def arequest_with_retry(
|
||||||
if attempt < max_retries - 1:
|
if attempt < max_retries - 1:
|
||||||
await asyncio.sleep(retry_delay)
|
await asyncio.sleep(retry_delay)
|
||||||
continue
|
continue
|
||||||
|
if session is None:
|
||||||
|
await _session.close()
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Failed after {max_retries} retries each. " f"Last error: {last_exception}"
|
f"Failed after {max_retries} retries each. "
|
||||||
|
f"Payload: {payload}. Addr: {addr}. Endpoint: {endpoint}. "
|
||||||
|
f"Last error: {last_exception}"
|
||||||
)
|
)
|
||||||
|
|
|
@ -8,7 +8,7 @@ from transformers import PreTrainedTokenizerFast
|
||||||
from arealite.api.cli_args import GenerationHyperparameters
|
from arealite.api.cli_args import GenerationHyperparameters
|
||||||
from arealite.api.io_struct import LLMRequest
|
from arealite.api.io_struct import LLMRequest
|
||||||
from arealite.api.workflow_api import RolloutWorkflow
|
from arealite.api.workflow_api import RolloutWorkflow
|
||||||
from arealite.utils.padding import concat_padded_tensors
|
from arealite.utils.data import concat_padded_tensors
|
||||||
|
|
||||||
|
|
||||||
class RLVRWorkflow(RolloutWorkflow):
|
class RLVRWorkflow(RolloutWorkflow):
|
||||||
|
@ -61,7 +61,7 @@ class RLVRWorkflow(RolloutWorkflow):
|
||||||
versions=torch.tensor(versions).unsqueeze(0),
|
versions=torch.tensor(versions).unsqueeze(0),
|
||||||
attention_mask=torch.ones(len(seq), dtype=torch.bool).unsqueeze(0),
|
attention_mask=torch.ones(len(seq), dtype=torch.bool).unsqueeze(0),
|
||||||
# reward
|
# reward
|
||||||
rewards=torch.tensor([reward]),
|
rewards=torch.tensor([float(reward)]),
|
||||||
)
|
)
|
||||||
results.append(TensorDict(res, batch_size=[1]))
|
results.append(TensorDict(res, batch_size=[1]))
|
||||||
|
|
||||||
|
|
|
@ -152,11 +152,7 @@ def main_grpo():
|
||||||
|
|
||||||
with stats_tracker.record_timing("rollout"):
|
with stats_tracker.record_timing("rollout"):
|
||||||
if config.async_training:
|
if config.async_training:
|
||||||
batch = rollout.prepare_batch(
|
batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
|
||||||
data_generator,
|
|
||||||
train_dataloader,
|
|
||||||
workflow=workflow,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
data = next(data_generator)
|
data = next(data_generator)
|
||||||
|
@ -210,8 +206,9 @@ def main_grpo():
|
||||||
actor.upload_weights(meta)
|
actor.upload_weights(meta)
|
||||||
if dist.get_rank() == 0:
|
if dist.get_rank() == 0:
|
||||||
future.result()
|
future.result()
|
||||||
rollout.set_version(global_step + 1)
|
|
||||||
dist.barrier()
|
dist.barrier()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
rollout.set_version(global_step + 1)
|
||||||
|
|
||||||
with stats_tracker.record_timing("save"):
|
with stats_tracker.record_timing("save"):
|
||||||
saver.save(actor, epoch, step, global_step)
|
saver.save(actor, epoch, step, global_step)
|
||||||
|
|
|
@ -1360,7 +1360,9 @@ class RayNameResolveRepository:
|
||||||
|
|
||||||
def make_repository(args: "NameResolveConfig"):
|
def make_repository(args: "NameResolveConfig"):
|
||||||
if args.type == "nfs":
|
if args.type == "nfs":
|
||||||
return NfsNameRecordRepository(args.nfs_record_root)
|
repo = NfsNameRecordRepository(args.nfs_record_root)
|
||||||
|
os.makedirs(repo.record_root, exist_ok=True)
|
||||||
|
return repo
|
||||||
elif args.type == "etcd3":
|
elif args.type == "etcd3":
|
||||||
host, port = args.etcd3_addr.split(":")
|
host, port = args.etcd3_addr.split(":")
|
||||||
return Etcd3NameRecordRepository(host=host, port=int(port))
|
return Etcd3NameRecordRepository(host=host, port=int(port))
|
||||||
|
|