mirror of https://github.com/inclusionAI/AReaL
[lite] [refactor] Add GSM8k GRPO example. (#179)
* PullRequest: 353 [Lite] Add gradient checkpointing to FSDPEngine
  Merge branch mzy/add-gradient-ckpt of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/353
  Reviewed-by: 博惟 <bowei.fw@antgroup.com>
* add gradient checkpointing
* PullRequest: 354 [lite] GRPO pre-commit: minor changes in FSDP engine
  Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/354
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* PullRequest: 355 [Lite] GRPO pre-commit 2: Refactor RemoteSGLangEngine thread and SGLang configuration
  Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/355?tab=commit
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* fix
* PullRequest: 357 [lite] GRPO pre-commit 3: Fix typos and experiment utilities
  Merge branch fw/lite-fix2 of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/357?tab=comment
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* fix destroy process group
* fix ci
* PullRequest: 358 [lite] Support GRPO training locally with the GSM8k dataset
  Merge branch fw/lite-fix3 of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/358
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* fix loss mask
* fix

---------

Co-authored-by: 晓雷 <meizhiyu.mzy@antgroup.com>
This commit is contained in:
parent 4490b117e4
commit e13db01f67
@@ -92,7 +92,7 @@ def main_grpo():
         future.result()

         # synchronous rollout
-        rollout_batch = rollout.rollout(batch, workflow=MyRolloutWorkflow(rollout_config.workflow))
+        rollout_batch = rollout.rollout_batch(batch, workflow=MyRolloutWorkflow(rollout_config.workflow))
         # or asynchronous rollout with filtering and off-policyness control
         # rollout_batch = rollout.prepare_batch(batch,
         #     workflow=MyRolloutWorkflow(rollout_config.workflow),
@@ -697,7 +697,7 @@ reward = TrainController(Critic())
 rollout_controller = RolloutController(...)
 for _ in range(epochs):
     for _ in range(steps_per_epoch):
-        data = rollout_controller.rollout(prompt)
+        data = rollout_controller.rollout_batch(prompt)
         data['reward'] = reward.compute_values(data)
         ...
 ```
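The two documentation hunks above rename the blocking rollout call from `rollout` to `rollout_batch`, with `prepare_batch` remaining the asynchronous entry point. The sketch below illustrates how the two calls could be driven from a training loop; `rollout`, `MyRolloutWorkflow`, and `rollout_config` come from the example above, while `dataloader` and `actor.train_step` are hypothetical names used only for illustration.

```python
# Illustrative sketch only -- the loop structure and `actor.train_step` are assumptions.
workflow = MyRolloutWorkflow(rollout_config.workflow)

for batch in dataloader:
    # Synchronous: block until every trajectory in the batch has finished generating.
    rollout_batch = rollout.rollout_batch(batch, workflow=workflow)

    # Asynchronous alternative: keep the generation queue full and return as soon as
    # a full batch of sufficiently fresh samples is available (off-policyness control).
    # rollout_batch = rollout.prepare_batch(dataloader, workflow=workflow)

    actor.train_step(rollout_batch)
```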
@ -2,7 +2,7 @@ import argparse
|
|||
import os
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import uvloop
|
||||
|
||||
|
@ -12,7 +12,6 @@ from hydra import initialize as hydra_init
|
|||
from omegaconf import MISSING, OmegaConf
|
||||
|
||||
from arealite.utils.fs import get_user_tmp
|
||||
from realhf.api.cli_args import OptimizerConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -84,6 +83,63 @@ class GenerationHyperparameters:
|
|||
|
||||
|
||||
# Train Engine Configs
|
||||
|
||||
|
||||
@dataclass
|
||||
class OptimizerConfig:
|
||||
"""Configuration for model optimization during training.
|
||||
Note:
|
||||
Set type to "empty" for models that won't be trained.
|
||||
"""
|
||||
|
||||
type: str = field(
|
||||
default="adam",
|
||||
metadata={"help": "Optimizer type", "choices": ["adam", "empty"]},
|
||||
)
|
||||
lr: float = field(default=2e-5, metadata={"help": "Learning rate"})
|
||||
weight_decay: float = field(default=0.05, metadata={"help": "Weight decay"})
|
||||
beta1: float = field(default=0.9, metadata={"help": "Adam beta1 parameter"})
|
||||
beta2: float = field(default=0.95, metadata={"help": "Adam beta2 parameter"})
|
||||
eps: float = field(default=1e-5, metadata={"help": "Adam epsilon parameter"})
|
||||
min_lr_ratio: float = field(
|
||||
default=0.0,
|
||||
metadata={
|
||||
"help": "Minimum learning rate ratio after annealing",
|
||||
},
|
||||
)
|
||||
lr_scheduler_type: str = field(
|
||||
default="constant",
|
||||
metadata={
|
||||
"help": "Learning rate scheduler type",
|
||||
"choices": ["linear", "cosine", "constant"],
|
||||
},
|
||||
)
|
||||
warmup_steps_proportion: float = field(
|
||||
default=0.001,
|
||||
metadata={
|
||||
"help": "Proportion of training steps for warmup",
|
||||
},
|
||||
)
|
||||
offload: bool = field(
|
||||
default=False, metadata={"help": "Enable optimizer state offloading"}
|
||||
)
|
||||
initial_loss_scale: float = field(
|
||||
default=2**32, metadata={"help": "Initial loss scaling factor"}
|
||||
)
|
||||
min_loss_scale: float = field(
|
||||
default=1.0, metadata={"help": "Minimum loss scaling factor"}
|
||||
)
|
||||
loss_scale_window: float = field(
|
||||
default=5, metadata={"help": "Window size for loss scaling adjustment"}
|
||||
)
|
||||
hysteresis: int = field(
|
||||
default=2, metadata={"help": "Hysteresis (scaling factor) for loss scaling"}
|
||||
)
|
||||
gradient_clipping: float = field(
|
||||
default=1.0, metadata={"help": "Gradient clipping threshold"}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FSDPWrapPolicy:
|
||||
transformer_layer_cls_to_wrap: Optional[List[str]] = field(
|
||||
|
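The new `OptimizerConfig` exposes `lr`, `warmup_steps_proportion`, `min_lr_ratio`, and `lr_scheduler_type`. As a rough guide to how these fields interact, the snippet below computes the learning rate for a cosine schedule with warmup; the engine itself uses `get_cosine_schedule_with_warmup` (imported later in `fsdp_engine.py`), so this standalone function is only an illustrative approximation of that arithmetic.

```python
import math

def cosine_lr_with_warmup(step: int, total_steps: int, lr: float = 2e-5,
                          warmup_steps_proportion: float = 0.001,
                          min_lr_ratio: float = 0.0) -> float:
    """Approximate LR implied by the OptimizerConfig fields above (illustrative only)."""
    warmup_steps = max(1, int(total_steps * warmup_steps_proportion))
    min_lr = lr * min_lr_ratio
    if step < warmup_steps:
        return lr * step / warmup_steps  # linear warmup from 0 to lr
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```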
@ -135,10 +191,11 @@ class TrainEngineConfig:
|
|||
mb_spec: MicroBatchSpec = field(default_factory=MicroBatchSpec)
|
||||
|
||||
# Training Backend Configuration
|
||||
disable_dropout: bool = field(default=False)
|
||||
gradient_checkpointing: bool = field(
|
||||
default=True, metadata={"help": "Enable gradient checkpointing"}
|
||||
)
|
||||
bf16: bool = field(default=False, metadata={"help": "Use bf16 precision"})
|
||||
dtype: str = field(default="float16", metadata={"help": "Parameter dtype."})
|
||||
optimizer: Optional[OptimizerConfig] = field(
|
||||
default=None, metadata={"help": "Optimizer configuration"}
|
||||
)
|
||||
|
@ -147,12 +204,94 @@ class TrainEngineConfig:
|
|||
hf: HFEngineConfig = field(default_factory=HFEngineConfig)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PPOActorConfig(TrainEngineConfig):
|
||||
# Core PPO/GRPO Parameters
|
||||
group_size: int = field(
|
||||
default=1, metadata={"help": "Number of sequences in each group"}
|
||||
)
|
||||
group_adv_norm: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Normalize advantages within each prompt group rather than globally"
|
||||
},
|
||||
)
|
||||
ppo_n_minibatches: int = field(
|
||||
default=4, metadata={"help": "Number of minibatches for each PPO update"}
|
||||
)
|
||||
eps_clip: float = field(
|
||||
default=0.2, metadata={"help": "Clipping factor for policy ratio"}
|
||||
)
|
||||
c_clip: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Dual clipping factor for policy ratio, must > 1.0. None disables dual clipping."
|
||||
},
|
||||
)
|
||||
temperature: float = field(
|
||||
default=1.0, metadata={"help": "Temperature during generation."}
|
||||
)
|
||||
# Reward
|
||||
group_reward_norm: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Normalize final reward of each sequence (GRPO-style) to reduce length bias"
|
||||
},
|
||||
)
|
||||
reward_scaling: float = field(
|
||||
default=1.0, metadata={"help": "Reward scaling factor"}
|
||||
)
|
||||
reward_bias: float = field(default=0.0, metadata={"help": "Reward bias"})
|
||||
reward_clip: float = field(
|
||||
default=20.0, metadata={"help": "Maximum absolute value for reward clipping"}
|
||||
)
|
||||
mask_no_eos_with_zero: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Mask truncated generations (no EOS token) and exclude from training"
|
||||
},
|
||||
)
|
||||
|
||||
# Advantage Estimation
|
||||
discount: float = field(
|
||||
default=1.0, metadata={"help": "Discount factor for future rewards"}
|
||||
)
|
||||
gae_lambda: float = field(
|
||||
default=1.0, metadata={"help": "Lambda parameter for GAE"}
|
||||
)
|
||||
adv_norm: bool = field(
|
||||
default=True, metadata={"help": "Enable advantage normalization"}
|
||||
)
|
||||
|
||||
# KL Control
|
||||
kl_ctl: float = field(default=0.1, metadata={"help": "KL divergence coefficient"})
|
||||
|
||||
# Asynchronous RL
|
||||
recompute_logprob: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Recompute logp and replace the logp returned by inference."},
|
||||
)
|
||||
use_decoupled_loss: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use the decoupled loss. recompute_logprob must be True."},
|
||||
)
|
||||
behav_imp_weight_cap: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "We filter out the tokens where behav_imp_weight exceeds behav_imp_weight_cap when computing the loss, must be > 1.0, use_decoupled_loss must be true"
|
||||
},
|
||||
)
|
||||
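`PPOActorConfig` gathers the GRPO-specific reward knobs (`group_size`, `group_reward_norm`, `reward_bias`, `reward_scaling`, `reward_clip`). The sketch below shows the reward shaping these fields imply, mirroring the `compute_advantages` logic added later in this commit; the assumption is that the batch holds `group_size` consecutive sequences per prompt.

```python
import torch

def shape_rewards(reward_score: torch.Tensor, group_size: int,
                  reward_bias: float = 0.0, reward_scaling: float = 1.0,
                  reward_clip: float = 20.0, group_reward_norm: bool = False) -> torch.Tensor:
    """Bias, scale, and clip rewards, then optionally normalize within each prompt group."""
    r = (reward_score + reward_bias) * reward_scaling
    r = torch.clip(r, min=-reward_clip, max=reward_clip)
    if group_reward_norm:
        for i in range(r.shape[0] // group_size):
            s = slice(i * group_size, (i + 1) * group_size)
            r[s] = (r[s] - r[s].mean()) / (r[s].std() + 1e-9)
    return r
```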
|
||||
|
||||
@dataclass
|
||||
class SGLangConfig:
|
||||
"""Configuration for SGLang runtime. Refer to:
|
||||
https://github.com/sgl-project/sglang for detailed documentation.
|
||||
"""
|
||||
|
||||
model_path: str = ""
|
||||
random_seed: int = 1
|
||||
skip_tokenizer_init: bool = False
|
||||
disable_cuda_graph: bool = False
|
||||
disable_radix_cache: bool = False
|
||||
disable_cuda_graph_padding: bool = False
|
||||
|
@ -188,10 +327,8 @@ class SGLangConfig:
|
|||
schedule_policy: str = "lpm"
|
||||
schedule_conservativeness: float = 1.0
|
||||
cpu_offload_gb: int = 0
|
||||
|
||||
dtype: str = "float16"
|
||||
kv_cache_dtype: str = "auto"
|
||||
|
||||
# logging
|
||||
log_level: str = "warning"
|
||||
log_level_http: Optional[str] = "warning"
|
||||
|
@ -207,21 +344,19 @@ class SGLangConfig:
|
|||
@staticmethod
|
||||
def build_cmd(
|
||||
sglang_config: "SGLangConfig",
|
||||
model_path,
|
||||
tp_size,
|
||||
base_gpu_id,
|
||||
host,
|
||||
port,
|
||||
dist_init_addr: Optional[str] = None,
|
||||
served_model_name: Optional[str] = None,
|
||||
skip_tokenizer_init: bool = True,
|
||||
):
|
||||
args = SGLangConfig.build_args(
|
||||
sglang_config=sglang_config,
|
||||
model_path=model_path,
|
||||
tp_size=tp_size,
|
||||
base_gpu_id=base_gpu_id,
|
||||
host=host,
|
||||
port=port,
|
||||
dist_init_addr=dist_init_addr,
|
||||
served_model_name=served_model_name,
|
||||
skip_tokenizer_init=skip_tokenizer_init,
|
||||
)
|
||||
|
||||
# convert to flags
|
||||
|
@ -240,48 +375,54 @@ class SGLangConfig:
|
|||
@staticmethod
|
||||
def build_args(
|
||||
sglang_config: "SGLangConfig",
|
||||
model_path,
|
||||
tp_size,
|
||||
base_gpu_id,
|
||||
host,
|
||||
port,
|
||||
dist_init_addr: Optional[str] = None,
|
||||
served_model_name: Optional[str] = None,
|
||||
skip_tokenizer_init: bool = True,
|
||||
):
|
||||
from realhf.base import network, pkg_version, seeding
|
||||
from realhf.base import pkg_version
|
||||
from realhf.experiments.common.utils import asdict as conf_as_dict
|
||||
|
||||
args: Dict = conf_as_dict(sglang_config)
|
||||
args["random_seed"] = seeding.get_seed()
|
||||
|
||||
if served_model_name is None:
|
||||
served_model_name = model_path
|
||||
host_ip = network.gethostip()
|
||||
host = "localhost" if not sglang_config.enable_metrics else host_ip
|
||||
args = dict(
|
||||
host=host,
|
||||
model_path=model_path,
|
||||
port=port,
|
||||
# Model and tokenizer
|
||||
tokenizer_path=model_path,
|
||||
tokenizer_path=sglang_config.model_path,
|
||||
tokenizer_mode="auto",
|
||||
load_format="auto",
|
||||
trust_remote_code=True,
|
||||
device="cuda",
|
||||
served_model_name=served_model_name,
|
||||
is_embedding=False,
|
||||
skip_tokenizer_init=skip_tokenizer_init,
|
||||
# Other runtime options
|
||||
tp_size=tp_size,
|
||||
# Because we have set CUDA_VISIBLE_DEVICES to a single GPU in each process
|
||||
base_gpu_id=base_gpu_id,
|
||||
nnodes=1,
|
||||
node_rank=0,
|
||||
# initialization addresses and ports
|
||||
dist_init_addr=dist_init_addr,
|
||||
**args,
|
||||
)
|
||||
|
||||
if pkg_version.is_version_less("sglang", "0.4.4"):
|
||||
sglang_version = pkg_version.get_version("sglang")
|
||||
if sglang_version:
|
||||
version_less_than_0_4_4 = (
|
||||
pkg_version.compare_versions(sglang_version, "0.4.4") < 0
|
||||
)
|
||||
version_less_than_0_4_3 = (
|
||||
pkg_version.compare_versions(sglang_version, "0.4.3") < 0
|
||||
)
|
||||
elif pkg_version.is_available("sglang"):
|
||||
version_less_than_0_4_4 = pkg_version.is_version_less("sglang", "0.4.4")
|
||||
version_less_than_0_4_3 = pkg_version.is_version_less("sglang", "0.4.3")
|
||||
else:
|
||||
raise ValueError(
|
||||
"A installed SGLang package or a specific SGLang version should be provided to build SGLang server cmd."
|
||||
)
|
||||
if version_less_than_0_4_4:
|
||||
args.pop("log_requests_level")
|
||||
if pkg_version.is_version_less("sglang", "0.4.3"):
|
||||
if version_less_than_0_4_3:
|
||||
args.pop("enable_nccl_nvls")
|
||||
args.pop("triton_attention_num_kv_splits")
|
||||
args.pop("cuda_graph_bs")
|
||||
|
@ -294,8 +435,8 @@ class SGLangConfig:
|
|||
|
||||
@dataclass
|
||||
class InferenceEngineConfig:
|
||||
experiment_name: str
|
||||
trial_name: str
|
||||
experiment_name: str = MISSING
|
||||
trial_name: str = MISSING
|
||||
max_concurrent_rollouts: None | int = field(
|
||||
default=None,
|
||||
metadata={
|
||||
|
@ -318,28 +459,20 @@ class InferenceEngineConfig:
|
|||
"the request will not be accepted.",
|
||||
},
|
||||
)
|
||||
# Used by remote inference engines.
|
||||
server_addrs: List[str] = field(
|
||||
default_factory=list,
|
||||
metadata={"help": "List of server addresses for inference."},
|
||||
)
|
||||
enable_rollout_tracing: bool = field(default=False)
|
||||
schedule_policy: str = field(
|
||||
default="round_robin",
|
||||
metadata={"help": "Request scheduling policy", "choices": ["round_robin"]},
|
||||
)
|
||||
setup_timeout: float = field(default=90.0)
|
||||
request_timeout: float = field(
|
||||
default=30.0, metadata={"help": "Timeout for HTTP requests."}
|
||||
default=3600, metadata={"help": "Timeout for HTTP requests."}
|
||||
)
|
||||
request_retries: int = field(
|
||||
default=3, metadata={"help": "Number of retries for failed requests."}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SGLangEngineConfig:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class _Timer:
|
||||
experiment_name: str = MISSING
|
||||
|
@ -569,39 +702,58 @@ class BaseExperimentConfig:
|
|||
evaluator: EvaluatorConfig = field(default_factory=EvaluatorConfig)
|
||||
stats_logger: StatsLoggerConfig = field(default_factory=StatsLoggerConfig)
|
||||
|
||||
server_only: bool = False
|
||||
sglang: SGLangConfig = field(default_factory=SGLangConfig)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SFTConfig(BaseExperimentConfig):
|
||||
model: TrainEngineConfig = field(default_factory=TrainEngineConfig)
|
||||
|
||||
|
||||
def load_expr_config(argv: List[str], config_cls) -> Tuple[BaseExperimentConfig, str]:
|
||||
@dataclass
|
||||
class GRPOConfig(BaseExperimentConfig):
|
||||
async_training: bool = field(default=True)
|
||||
gconfig: GenerationHyperparameters = field(
|
||||
default_factory=GenerationHyperparameters
|
||||
)
|
||||
rollout: InferenceEngineConfig = field(default_factory=InferenceEngineConfig)
|
||||
actor: PPOActorConfig = field(default_factory=PPOActorConfig)
|
||||
ref: PPOActorConfig = field(default_factory=PPOActorConfig)
|
||||
|
||||
|
||||
def parse_cli_args(argv: List[str]):
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--config", help="The path of the main configuration file", required=True
|
||||
)
|
||||
args, overrides = parser.parse_known_args(argv)
|
||||
|
||||
# Initialize hydra config
|
||||
config_file = Path(args.config).absolute()
|
||||
assert config_file.exists()
|
||||
# hydra only recognizes relative paths
|
||||
relpath = Path(
|
||||
os.path.relpath(str(config_file), (Path(__file__).parent).absolute())
|
||||
)
|
||||
relpath = Path(os.path.relpath(str(config_file), Path(__file__).parent.absolute()))
|
||||
hydra_init(config_path=str(relpath.parent), job_name="app", version_base=None)
|
||||
cfg = hydra_compose(
|
||||
config_name=str(relpath.name).rstrip(".yaml"),
|
||||
config_name=str(relpath.name).split(".yaml")[0],
|
||||
overrides=overrides,
|
||||
)
|
||||
return cfg, config_file
|
||||
|
||||
|
||||
def to_structured_cfg(cfg, config_cls):
|
||||
# Merge with the default configuration.
|
||||
# The YAML file and command-line overrides can omit default values defined in the Python dataclasses.
|
||||
default_cfg = OmegaConf.structured(config_cls)
|
||||
cfg = OmegaConf.merge(default_cfg, cfg)
|
||||
return cfg
|
||||
|
||||
|
||||
def load_expr_config(argv: List[str], config_cls):
|
||||
cfg, config_file = parse_cli_args(argv)
|
||||
cfg = to_structured_cfg(cfg, config_cls=config_cls)
|
||||
cfg = OmegaConf.to_object(cfg)
|
||||
assert isinstance(cfg, BaseExperimentConfig)
|
||||
|
||||
# Setup environment
|
||||
from realhf.base import constants, name_resolve
|
||||
|
||||
|
|
|
@ -4,7 +4,9 @@ from dataclasses import dataclass, field
|
|||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from tensordict import TensorDict
|
||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
|
||||
from arealite.api.io_struct import (
|
||||
FinetuneSpec,
|
||||
|
@ -40,6 +42,11 @@ class TrainEngine(abc.ABC):
|
|||
"""Initialize environments for distributed training and load models."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def parallelism_group(self) -> dist.ProcessGroup:
|
||||
"""The global communication group of this engine."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_scheduling_config(self) -> Scheduling:
|
||||
"""Get the scheduling configuration for the engine, e.g., image, cpu/gpu/memory size."""
|
||||
raise NotImplementedError()
|
||||
|
@ -77,9 +84,9 @@ class TrainEngine(abc.ABC):
|
|||
|
||||
def train_batch(
|
||||
self,
|
||||
input_: Dict,
|
||||
loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
|
||||
loss_weight_fn: Callable[[Dict], float],
|
||||
input_: TensorDict,
|
||||
loss_fn: Callable[[torch.Tensor, TensorDict], torch.Tensor],
|
||||
loss_weight_fn: Callable[[TensorDict], float],
|
||||
) -> Dict[str, float]:
|
||||
"""Update the model with a batch of data and a loss function."""
|
||||
raise NotImplementedError()
|
||||
|
@ -87,9 +94,9 @@ class TrainEngine(abc.ABC):
|
|||
@torch.no_grad()
|
||||
def eval_batch(
|
||||
self,
|
||||
input_: Dict,
|
||||
loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
|
||||
loss_weight_fn: Callable[[Dict], float],
|
||||
input_: TensorDict,
|
||||
loss_fn: Callable[[torch.Tensor, TensorDict], torch.Tensor],
|
||||
loss_weight_fn: Callable[[TensorDict], float],
|
||||
) -> torch.Tensor | None:
|
||||
"""Evaluate the model using the forward pass and loss function."""
|
||||
raise NotImplementedError()
|
||||
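`train_batch` and `eval_batch` now take a `TensorDict` and a loss function with signature `Callable[[torch.Tensor, TensorDict], torch.Tensor]`. A minimal caller might look like the sketch below; the next-token cross-entropy loss and the `input_ids`/`loss_mask` field layout are assumptions for illustration, not the library's canonical loss.

```python
import torch
from tensordict import TensorDict

def example_loss_fn(logits: torch.Tensor, input_: TensorDict) -> torch.Tensor:
    # Next-token log-likelihood over positions selected by loss_mask (assumed layout).
    labels = torch.roll(input_["input_ids"], shifts=-1, dims=-1)
    logp = torch.log_softmax(logits, dim=-1)
    token_logp = torch.gather(logp, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
    mask = input_["loss_mask"].bool()
    return -token_logp[mask].sum() / mask.count_nonzero()

# stats = engine.train_batch(
#     input_=batch,  # a TensorDict with "input_ids", "loss_mask", etc.
#     loss_fn=example_loss_fn,
#     loss_weight_fn=lambda x: x["loss_mask"].count_nonzero(),
# )
```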
|
@ -97,9 +104,9 @@ class TrainEngine(abc.ABC):
|
|||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
input_: Dict,
|
||||
input_: TensorDict,
|
||||
output_seqlens: List[List[int]] | None = None,
|
||||
post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
|
||||
post_hook: Callable[[torch.Tensor, TensorDict], Any] | None = None,
|
||||
aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
|
||||
) -> Any | None:
|
||||
"""Run the forward pass or inference on the model. Note that it is gradient-free."""
|
||||
|
@ -127,12 +134,33 @@ class InferenceEngine(abc.ABC):
|
|||
"""Asynchronously submit a request to the inference engine. Exits immediately."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def wait(self, count: int, timeout: float) -> TensorDict:
|
||||
def wait(
|
||||
self,
|
||||
count: int,
|
||||
timeout: float | None = None,
|
||||
should_accept: Callable | None = None,
|
||||
) -> TensorDict:
|
||||
"""Wait for a specified number of requests to complete, with a timeout."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def rollout(
|
||||
def rollout_batch(
|
||||
self, data: List[Dict[str, Any]], workflow: "RolloutWorkflow"
|
||||
) -> TensorDict:
|
||||
"""Submit a batch of requests to the inference engine and wait for the results."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def prepare_batch(
|
||||
self,
|
||||
dataloader: StatefulDataLoader,
|
||||
workflow: "RolloutWorkflow",
|
||||
):
|
||||
"""Asynchronously submit and wait until a full batch is ready."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def pause(self):
|
||||
"""Pause request submission for async rollout. Used during evaluation to prevent data over generation."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def resume(self):
|
||||
"""Resume request submission for async rollout."""
|
||||
raise NotImplementedError()
|
||||
|
|
|
@ -16,7 +16,6 @@ from arealite.api.cli_args import GenerationHyperparameters
|
|||
@dataclass
|
||||
class LLMRequest:
|
||||
rid: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
text: Optional[str] = None
|
||||
input_ids: List[int] = field(default_factory=list)
|
||||
gconfig: GenerationHyperparameters = field(
|
||||
default_factory=GenerationHyperparameters
|
||||
|
@ -28,7 +27,6 @@ class LLMRequest:
|
|||
@dataclass
|
||||
class LLMResponse:
|
||||
# outputs
|
||||
completions: str
|
||||
input_tokens: List[int] = field(default_factory=list)
|
||||
output_tokens: List[int] = field(default_factory=list)
|
||||
output_logprobs: List[float] = field(default_factory=list)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import gc
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
|
@ -32,7 +33,7 @@ from arealite.utils.data import (
|
|||
pad_and_stack_tensors_along_first_dim,
|
||||
pad_mb_list,
|
||||
reorder_list,
|
||||
split_packed_tensor_dict_into_mb_list,
|
||||
split_padded_tensor_dict_into_mb_list,
|
||||
unpack_sequence,
|
||||
unsqueeze_mb_list,
|
||||
)
|
||||
|
@ -45,9 +46,10 @@ from arealite.utils.fsdp import (
|
|||
fsdp2_load_full_state_dict,
|
||||
get_cosine_schedule_with_warmup,
|
||||
)
|
||||
from arealite.utils.model import disable_dropout_in_model
|
||||
from arealite.utils.save_load import get_state_dict_from_repo_id_or_path
|
||||
from realhf.api.core.data_api import load_hf_tokenizer
|
||||
from realhf.base import logging, name_resolve, names, pkg_version
|
||||
from realhf.base import constants, logging, name_resolve, names, pkg_version
|
||||
|
||||
logger = logging.getLogger("FSDPEngine")
|
||||
|
||||
|
@ -68,6 +70,8 @@ class FSDPEngine(TrainEngine):
|
|||
self.cpu_offload = None
|
||||
# initialization
|
||||
self.initialized = False
|
||||
self.own_global_group = False
|
||||
self._parallelism_group = None
|
||||
self.weight_update_group_initialized = False
|
||||
|
||||
# TODO: Handle the case when WORLD_SIZE is not set in launcher
|
||||
|
@ -78,6 +82,11 @@ class FSDPEngine(TrainEngine):
|
|||
self.model.train(mode=mode)
|
||||
return self
|
||||
|
||||
@property
|
||||
def parallelism_group(self) -> dist.ProcessGroup:
|
||||
assert self.initialized
|
||||
return self._parallelism_group
|
||||
|
||||
def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None):
|
||||
# Initialize distributed environments and load the model.
|
||||
assert addr is None, "FSDPEngine does not support remote initialization."
|
||||
|
@ -89,29 +98,54 @@ class FSDPEngine(TrainEngine):
|
|||
"""Initialize distributed communication and model."""
|
||||
if not dist.is_initialized():
|
||||
# TODO: Handle the condition when WORLD_SIZE and RANK are not set by the launcher
|
||||
dist.init_process_group(backend="nccl")
|
||||
dist.init_process_group(
|
||||
backend="nccl",
|
||||
timeout=constants.NCCL_DEFAULT_TIMEOUT,
|
||||
device_id=torch.device(int(os.environ["LOCAL_RANK"])),
|
||||
)
|
||||
self.own_global_group = True
|
||||
self._parallelism_group = dist.new_group()
|
||||
|
||||
# TODO: Handle the condition when LOCAL_RANK is not set in launcher
|
||||
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
|
||||
self.device = torch.device(int(os.environ["LOCAL_RANK"]))
|
||||
|
||||
dtype = torch.bfloat16 if self.config.bf16 else torch.float16
|
||||
dtype = getattr(torch, self.config.dtype)
|
||||
self.model_config = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path=self.config.path,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
self.tokenizer = load_hf_tokenizer(self.config.path)
|
||||
tik = time.perf_counter()
|
||||
with torch.device("cuda"):
|
||||
# initialize scratch model from config
|
||||
model = AutoModelForCausalLM.from_config(
|
||||
self.model_config,
|
||||
torch_dtype=dtype,
|
||||
attn_implementation=self.config.attn_impl,
|
||||
if self.config.init_from_scratch:
|
||||
# initialize scratch model from config
|
||||
# NOTE: VLMs cannot directly load a state dict into a randomly
# initialized model, so in the other branch we call from_pretrained
# instead of loading weights into this random model.
|
||||
model = AutoModelForCausalLM.from_config(
|
||||
self.model_config,
|
||||
torch_dtype=dtype,
|
||||
attn_implementation=self.config.attn_impl,
|
||||
)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
pretrained_model_name_or_path=self.config.path,
|
||||
trust_remote_code=True,
|
||||
torch_dtype=dtype,
|
||||
attn_implementation=self.config.attn_impl,
|
||||
)
|
||||
if self.config.disable_dropout:
|
||||
disable_dropout_in_model(model)
|
||||
if self.config.gradient_checkpointing:
|
||||
model.gradient_checkpointing_enable(
|
||||
gradient_checkpointing_kwargs={"use_reentrant": False}
|
||||
)
|
||||
logger.info(f"Model creation and loading time: {time.perf_counter() - tik}")
|
||||
|
||||
# Simple auto wrap policy
|
||||
self.mixed_precision_policy = MixedPrecisionPolicy(
|
||||
param_dtype=torch.bfloat16,
|
||||
param_dtype=dtype,
|
||||
reduce_dtype=torch.float32,
|
||||
cast_forward_inputs=True,
|
||||
)
|
||||
|
@ -129,23 +163,14 @@ class FSDPEngine(TrainEngine):
|
|||
}
|
||||
|
||||
# Wrap with FSDP2
|
||||
tik = time.perf_counter()
|
||||
apply_fsdp2(model, fsdp_kwargs, self.config.fsdp.wrap_policy)
|
||||
logger.info(f"Applying FSDP2 time: {time.perf_counter() - tik}")
|
||||
self.model = model
|
||||
|
||||
if not self.config.init_from_scratch:
|
||||
# Load model from a initial checkpoint path,
|
||||
# which should only be a huggingface checkpoint.
|
||||
load_meta = SaveLoadMeta(
|
||||
path=self.config.path,
|
||||
weight_format="hf",
|
||||
with_optim=False,
|
||||
tokenizer=None,
|
||||
base_model_path=self.config.path,
|
||||
)
|
||||
self.load(load_meta)
|
||||
|
||||
# Set up optimizer
|
||||
if self.optimizer_config is not None:
|
||||
tik = time.perf_counter()
|
||||
assert (
|
||||
self.optimizer_config.type == "adam"
|
||||
), "Only AdamW optimizer is supported in this engine."
|
||||
|
@ -189,6 +214,7 @@ class FSDPEngine(TrainEngine):
|
|||
raise ValueError(
|
||||
f"Unknown lr scheduler type {self.optimizer_config.lr_scheduler_type}"
|
||||
)
|
||||
logger.info(f"Create optimizer time: {time.perf_counter() - tik}")
|
||||
|
||||
self.initialized = True
|
||||
|
||||
|
@ -199,6 +225,9 @@ class FSDPEngine(TrainEngine):
|
|||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
dist.destroy_process_group(self.parallelism_group)
|
||||
if self.own_global_group:
|
||||
dist.destroy_process_group()
|
||||
self.initialized = False
|
||||
|
||||
def save(self, meta: SaveLoadMeta):
|
||||
|
@ -300,7 +329,9 @@ class FSDPEngine(TrainEngine):
|
|||
self.config.trial_name,
|
||||
meta.model_version,
|
||||
)
|
||||
name_resolve.add(update_name, str(time.time_ns()), keepalive_ttl=120)
|
||||
name_resolve.add(
|
||||
update_name, str(datetime.now().timestamp()), keepalive_ttl=120
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown weight update type {meta.type}")
|
||||
|
||||
|
@ -323,15 +354,19 @@ class FSDPEngine(TrainEngine):
|
|||
if isinstance(input_, dict):
|
||||
input_ = TensorDict(input_, batch_size=[input_["input_ids"].shape[0]])
|
||||
input_ = amend_position_ids(input_)
|
||||
packed_input = pack_tensor_dict(input_)
|
||||
mb_list = split_packed_tensor_dict_into_mb_list(
|
||||
packed_input,
|
||||
self.config.mb_spec,
|
||||
mb_list = split_padded_tensor_dict_into_mb_list(input_, self.config.mb_spec)
|
||||
logger.info(
|
||||
f"Microbatch #tokens (rank {dist.get_rank()}): {mb_list.group_lens}"
|
||||
)
|
||||
mb_list.mbs = [pack_tensor_dict(mb) for mb in mb_list.mbs]
|
||||
mb_list = pad_mb_list(mb_list, pad_value=0.0)
|
||||
# NOTE: We unsqueeze here because huggingface transformer models require
# packed input to be of shape [1, total_seqlen].
|
||||
mb_list = unsqueeze_mb_list(mb_list)
|
||||
# FIXME: the resulting max_seqlen is a tensor rather than an integer
|
||||
for mb in mb_list.mbs:
|
||||
mb["max_seqlen"] = int(mb["max_seqlen"])
|
||||
mb["use_cache"] = False
|
||||
return mb_list
|
||||
|
||||
def train_batch(
|
||||
|
@ -356,9 +391,10 @@ class FSDPEngine(TrainEngine):
|
|||
dist.all_reduce(total_loss_weight)
|
||||
|
||||
# Process microbatches with gradient accumulation
|
||||
for pad_length, padded_mb_input, mb_input in zip(
|
||||
mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
|
||||
for i, (pad_length, padded_mb_input, mb_input) in enumerate(
|
||||
zip(mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs)
|
||||
):
|
||||
self.model.set_is_last_backward(i == len(mb_list.mbs) - 1)
|
||||
outputs = self.model(**padded_mb_input)
|
||||
|
||||
logits = outputs.logits.squeeze(0)
|
||||
|
|
|
@ -15,7 +15,7 @@ from transformers import (
|
|||
get_linear_schedule_with_warmup,
|
||||
)
|
||||
|
||||
from arealite.api.cli_args import MicroBatchSpec, TrainEngineConfig
|
||||
from arealite.api.cli_args import TrainEngineConfig
|
||||
from arealite.api.engine_api import (
|
||||
FinetuneSpec,
|
||||
SaveLoadMeta,
|
||||
|
@ -81,7 +81,7 @@ class HFEngine(TrainEngine):
|
|||
torch.cuda.set_device(local_rank)
|
||||
self.device = torch.device(f"cuda:{local_rank}")
|
||||
|
||||
dtype = torch.bfloat16 if self.config.bf16 else torch.float16
|
||||
dtype = getattr(torch, self.config.dtype)
|
||||
self.model_config = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path=self.config.path,
|
||||
trust_remote_code=True,
|
||||
|
|
|
@ -0,0 +1,334 @@
|
|||
import functools
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from tensordict import TensorDict
|
||||
|
||||
from arealite.api.cli_args import MicroBatchSpec, PPOActorConfig
|
||||
from arealite.api.engine_api import TrainEngine
|
||||
from arealite.engine.fsdp_engine import FSDPEngine
|
||||
from arealite.utils.data import split_padded_tensor_dict_into_mb_list
|
||||
from arealite.utils.functional import (
|
||||
gather_logprobs,
|
||||
gather_logprobs_entropy,
|
||||
masked_normalization,
|
||||
ppo_actor_loss_fn,
|
||||
)
|
||||
from realhf.base import stats_tracker
|
||||
|
||||
|
||||
class PPOActor:
|
||||
|
||||
def __init__(self, config: PPOActorConfig, engine: TrainEngine):
|
||||
self.config = config
|
||||
self.engine = engine
|
||||
|
||||
self.reward_bias = config.reward_bias
|
||||
self.reward_scaling = config.reward_scaling
|
||||
self.reward_clip = config.reward_clip
|
||||
|
||||
self.group_reward_norm = config.group_reward_norm
|
||||
self.group_adv_norm = config.group_adv_norm
|
||||
self.group_size = config.group_size
|
||||
|
||||
self.kl_ctl = config.kl_ctl
|
||||
|
||||
self.adv_norm = config.adv_norm
|
||||
self.discount = config.discount
|
||||
self.gae_lambda = config.gae_lambda
|
||||
self.mask_no_eos_with_zero = config.mask_no_eos_with_zero
|
||||
|
||||
self.temperature = config.temperature
|
||||
|
||||
@torch.no_grad()
|
||||
def compute_logp(
|
||||
self,
|
||||
data: TensorDict,
|
||||
temperature: Optional[float] = None,
|
||||
) -> torch.Tensor | None:
|
||||
|
||||
def calc_logprobs(logits, input_data):
|
||||
labels = torch.roll(input_data["input_ids"], shifts=-1, dims=-1)
|
||||
logprobs = gather_logprobs(logits, labels, temperature or 1.0)
|
||||
return logprobs
|
||||
|
||||
self.engine.eval()
|
||||
return self.engine.forward(
|
||||
input_=data,
|
||||
post_hook=calc_logprobs,
|
||||
aggregate_fn=lambda xs: torch.cat(xs, dim=-1),
|
||||
)
|
||||
|
||||
def compute_advantages(self, data: TensorDict) -> None:
|
||||
bs = data["input_ids"].shape[0]
|
||||
max_seqlen = data["input_ids"].shape[1]
|
||||
batch_indices = torch.arange(
|
||||
bs, device=data["input_ids"].device, dtype=torch.long
|
||||
)
|
||||
|
||||
# Compute rewards using the reward function in synchronous RLVR pipeline.
|
||||
reward_score = data["rewards"]
|
||||
reward_score = (reward_score + self.reward_bias) * self.reward_scaling
|
||||
reward_score = torch.clip(
|
||||
reward_score, max=self.reward_clip, min=-self.reward_clip
|
||||
)
|
||||
if self.group_reward_norm:
|
||||
for i in range(bs // self.group_size):
|
||||
s = slice(i * self.group_size, (i + 1) * self.group_size)
|
||||
r = reward_score[s]
|
||||
reward_score[s] = (r - r.mean()) / (r.std() + 1e-9)
|
||||
|
||||
loss_mask = data["loss_mask"].float()
|
||||
loss_mask = torch.roll(loss_mask, shifts=-1, dims=-1)
|
||||
# Apply the mask to log probabilities.
|
||||
if not self.config.use_decoupled_loss and self.config.recompute_logprob:
|
||||
# Overwrite logprobs produced by the inference engine
|
||||
old_logp = data["logprobs"] = data["prox_logp"]
|
||||
else:
|
||||
old_logp = torch.roll(data["logprobs"], shifts=-1, dims=-1)
|
||||
if not self.config.use_decoupled_loss:
|
||||
# proximal logp is not available; fall back to the logp returned by inference
|
||||
data["prox_logp"] = old_logp
|
||||
ref_logp = data.get("ref_logp", torch.zeros_like(old_logp))
|
||||
ref_logp *= loss_mask
|
||||
old_logp *= loss_mask
|
||||
|
||||
# Compute KL-regularized rewards.
|
||||
attn_mask = data["attention_mask"]
|
||||
seqlens = attn_mask.sum(-1).long()
|
||||
seq_no_eos_mask = seqlens == attn_mask.shape[1]
|
||||
rewards = -self.kl_ctl * (old_logp - ref_logp)
|
||||
kl_rewards = rewards.clone()
|
||||
# KL rewards at the next token after eos is zero.
|
||||
rewards[batch_indices, seqlens - 1] = 0
|
||||
indices = torch.clip(seqlens - 2, min=0)
|
||||
if self.mask_no_eos_with_zero:
|
||||
rewards[batch_indices, indices] += torch.where(
|
||||
seq_no_eos_mask, 0, reward_score
|
||||
)
|
||||
else:
|
||||
rewards[batch_indices, indices] += reward_score
|
||||
|
||||
# Compute GAE.
|
||||
if "values" not in data:
|
||||
values = torch.zeros_like(rewards)
|
||||
else:
|
||||
values = data["values"]
|
||||
advantages_reversed = []
|
||||
lastgaelam = 0
|
||||
for t in reversed(range(max_seqlen - 1)):
|
||||
nextvalues = values[:, t + 1]
|
||||
if t == max_seqlen - 2:
|
||||
nextvalues *= seq_no_eos_mask
|
||||
delta = rewards[:, t] + self.discount * nextvalues - values[:, t]
|
||||
lastgaelam = delta + self.discount * self.gae_lambda * lastgaelam
|
||||
advantages_reversed.append(lastgaelam)
|
||||
advantages_reversed.append(
|
||||
torch.zeros(bs, dtype=torch.float32, device=values.device)
|
||||
)
|
||||
advantages = torch.stack(advantages_reversed[::-1], dim=1)
|
||||
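# NOTE: the loop above evaluates the GAE recurrence
#   delta_t = r_t + discount * V_{t+1} - V_t,  A_t = delta_t + discount * gae_lambda * A_{t+1}
# backwards in time; with all-zero values (the GRPO case) and discount = gae_lambda = 1,
# the advantage reduces to the reward-to-go of each token.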
|
||||
# Optionally perform advantage normalization.
|
||||
if self.adv_norm:
|
||||
if self.group_adv_norm:
|
||||
adv_list = []
|
||||
for i in range(0, bs, self.group_size):
|
||||
s = slice(i * self.group_size, (i + 1) * self.group_size)
|
||||
adv = advantages[s]
|
||||
m = loss_mask[s]
|
||||
adv_list.append(masked_normalization(adv, m, all_reduce=False))
|
||||
advantages = torch.cat(adv_list, 0)
|
||||
else:
|
||||
advantages = masked_normalization(advantages, loss_mask)
|
||||
|
||||
# Store data in the dict.
|
||||
data["advantages"] = advantages
|
||||
data["kl_rewards"] = kl_rewards
|
||||
data["tot_rewards"] = rewards
|
||||
data["loss_mask"] = loss_mask
|
||||
# because we have rolled old_logp by -1
|
||||
data["logprobs"] = old_logp
|
||||
|
||||
def ppo_update(self, data: TensorDict) -> List[Dict[str, float]]:
|
||||
attn_mask = data["attention_mask"]
|
||||
loss_mask = data["loss_mask"]
|
||||
reward_score = data["rewards"]
|
||||
seqlens = attn_mask.sum(-1)
|
||||
|
||||
all_stats = []
|
||||
########## Logging code starts ##########
|
||||
result_denominators = {
|
||||
"correct_n_seqs": (reward_score > 0).bool(),
|
||||
"incorrect_n_seqs": (reward_score <= 0).bool(),
|
||||
}
|
||||
global_denominators = dict(
|
||||
n_seqs=torch.ones_like(reward_score, dtype=torch.bool),
|
||||
n_tokens=torch.ones_like(loss_mask, dtype=torch.bool),
|
||||
n_valid_tokens=loss_mask.bool(),
|
||||
**result_denominators,
|
||||
)
|
||||
stats_tracker.denominator(**global_denominators)
|
||||
stats_tracker.stat(
|
||||
correct_seq_len=seqlens.float(), denominator="correct_n_seqs"
|
||||
)
|
||||
stats_tracker.stat(
|
||||
incorrect_seq_len=seqlens.float(), denominator="incorrect_n_seqs"
|
||||
)
|
||||
|
||||
stats = dict(
|
||||
advantages=data["advantages"],
|
||||
kl_rewards=data["kl_rewards"],
|
||||
final_reward=data["tot_rewards"],
|
||||
)
|
||||
stats_tracker.stat(**stats, denominator="n_valid_tokens")
|
||||
|
||||
prompt_lens = []
|
||||
prompt_lens = data["attention_mask"].sum(-1) - data["loss_mask"].sum(-1)
|
||||
seq_stats = dict(
|
||||
no_eos_ratios=(seqlens == attn_mask.shape[-1]).float(),
|
||||
task_reward=reward_score.float(),
|
||||
prompt_len=prompt_lens.float(),
|
||||
seq_len=seqlens.float(),
|
||||
)
|
||||
stats_tracker.stat(**seq_stats, denominator="n_seqs")
|
||||
scalars = dict(
|
||||
mask_no_eos_with_zero=self.config.mask_no_eos_with_zero,
|
||||
eps_clip=self.config.eps_clip,
|
||||
)
|
||||
if self.config.c_clip is not None:
|
||||
scalars["c_clip"] = self.config.c_clip
|
||||
scalars["use_dual_clip"] = 1
|
||||
else:
|
||||
scalars["use_dual_clip"] = 0
|
||||
if self.config.behav_imp_weight_cap is not None:
|
||||
scalars["behav_imp_weight_cap"] = self.config.behav_imp_weight_cap
|
||||
stats_tracker.scalar(**scalars)
|
||||
|
||||
global_stats = stats_tracker.export(reduce_group=self.engine.parallelism_group)
|
||||
for k in global_denominators:
|
||||
keys = list(global_stats.keys())
|
||||
for k2 in keys:
|
||||
if k2.endswith(k):
|
||||
global_stats.pop(k2)
|
||||
########## Logging code ends ##########
|
||||
|
||||
for key in ["rewards", "tot_rewards", "kl_rewards", "versions"]:
|
||||
data.pop(key, None)
|
||||
# NOTE: calling engine.train() is critical to enabling gradient checkpointing
|
||||
self.engine.train()
|
||||
mb_inputs = split_padded_tensor_dict_into_mb_list(
|
||||
data,
|
||||
mb_spec=MicroBatchSpec(n_mbs=self.config.ppo_n_minibatches),
|
||||
)
|
||||
for mb in mb_inputs.mbs:
|
||||
train_stat = self.engine.train_batch(
|
||||
mb,
|
||||
loss_fn=functools.partial(
|
||||
grpo_loss_fn,
|
||||
temperature=self.temperature,
|
||||
eps_clip=self.config.eps_clip,
|
||||
c_clip=self.config.c_clip,
|
||||
behav_imp_weight_cap=self.config.behav_imp_weight_cap,
|
||||
),
|
||||
loss_weight_fn=lambda x: x["loss_mask"].count_nonzero(),
|
||||
)
|
||||
stats_tracker.scalar(**train_stat)
|
||||
all_stats.append(
|
||||
stats_tracker.export(reduce_group=self.engine.parallelism_group)
|
||||
)
|
||||
all_stats[0].update(global_stats)
|
||||
return all_stats
|
||||
|
||||
|
||||
class FSDPPPOActor(FSDPEngine):
|
||||
|
||||
def __init__(self, config: PPOActorConfig):
|
||||
super().__init__(config)
|
||||
self.actor = PPOActor(config, self)
|
||||
|
||||
@torch.no_grad()
|
||||
def compute_logp(self, *args, **kwargs) -> torch.Tensor | None:
|
||||
return self.actor.compute_logp(*args, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
def compute_advantages(self, *args, **kwargs) -> None:
|
||||
self.actor.compute_advantages(*args, **kwargs)
|
||||
|
||||
def ppo_update(self, *args, **kwargs) -> List[Dict[str, float]]:
|
||||
return self.actor.ppo_update(*args, **kwargs)
|
||||
|
||||
|
||||
def grpo_loss_fn(
|
||||
logits: torch.Tensor,
|
||||
input_data: Dict,
|
||||
temperature: float,
|
||||
eps_clip: float,
|
||||
c_clip: float | None,
|
||||
behav_imp_weight_cap: float | None,
|
||||
):
|
||||
"""Loss function for actor step, all inputs should be splitted into
|
||||
pipeline micro batches, returns loss and logging stats."""
|
||||
input_ids = input_data["input_ids"]
|
||||
old_logp = input_data["logprobs"]
|
||||
advantages = input_data["advantages"]
|
||||
loss_mask = input_data["loss_mask"].bool()
|
||||
prox_logp = input_data["prox_logp"]
|
||||
|
||||
logprobs, entropy = gather_logprobs_entropy(
|
||||
logits, torch.roll(input_ids, shifts=-1, dims=-1), temperature
|
||||
)
|
||||
entropy = entropy.detach()
|
||||
loss, stat = ppo_actor_loss_fn(
|
||||
logprobs=logprobs,
|
||||
old_logprobs=old_logp,
|
||||
advantages=advantages,
|
||||
eps_clip=eps_clip,
|
||||
loss_mask=loss_mask,
|
||||
c_clip=c_clip,
|
||||
proximal_logprobs=prox_logp,
|
||||
behav_imp_weight_cap=behav_imp_weight_cap,
|
||||
)
|
||||
|
||||
# Log training statistics
|
||||
stats_tracker.denominator(
|
||||
n_tokens=torch.ones(logits.shape[0], dtype=torch.bool, device=logits.device),
|
||||
n_valid_tokens=loss_mask.bool(),
|
||||
clipped_tokens=stat["clip_mask"],
|
||||
dual_clipped_tokens=stat["dual_clip_mask"],
|
||||
)
|
||||
|
||||
stats_tracker.stat(
|
||||
importance_weight=stat["importance_weight"],
|
||||
approx_kl=stat["approx_kl"],
|
||||
new_logp=logprobs.detach(),
|
||||
old_logp=old_logp,
|
||||
entropy=entropy.float(),
|
||||
actor_loss=stat["loss"],
|
||||
clip_ratio=stat["clip_mask"].float(),
|
||||
dual_clip_ratio=stat["dual_clip_mask"].float(),
|
||||
denominator="n_valid_tokens",
|
||||
)
|
||||
if "behave_imp_weight" in stat:
|
||||
stats_tracker.denominator(unclipped_behave_tokens=stat["behave_mask"])
|
||||
stats_tracker.stat(
|
||||
behave_imp_weight=stat["behave_imp_weight"],
|
||||
behave_approx_kl=stat["behave_approx_kl"],
|
||||
denominator="unclipped_behave_tokens",
|
||||
)
|
||||
vocab_min_logits = logits.detach().min(-1).values.float()
|
||||
vocab_max_logits = logits.detach().max(-1).values.float()
|
||||
stats_tracker.stat(
|
||||
vocab_min_logits=vocab_min_logits,
|
||||
vocab_max_logits=vocab_max_logits,
|
||||
denominator="n_tokens",
|
||||
)
|
||||
|
||||
clip_mask = stat["clip_mask"]
|
||||
clipped_new_logp = torch.where(clip_mask, logprobs.detach(), 0.0)
|
||||
clipped_old_logp = torch.where(clip_mask, old_logp, 0.0)
|
||||
stats_tracker.stat(
|
||||
clipped_new_logp=clipped_new_logp,
|
||||
clipped_old_logp=clipped_old_logp,
|
||||
denominator="clipped_tokens",
|
||||
)
|
||||
return loss
|
|
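Taken together, the new `FSDPPPOActor` methods give a GRPO step that chains `compute_logp`, `compute_advantages`, and `ppo_update`. The sketch below is an assumed driver, not the actual GSM8k example script: the branch on `recompute_logprob`/`use_decoupled_loss` follows the config semantics shown above, while the surrounding names are illustrative.

```python
# Hypothetical GRPO training step built on the actor API added in this commit.
def grpo_step(actor, ref, rollout_batch):
    # Optionally recompute log-probs under the current policy (prox_logp).
    if actor.config.recompute_logprob or actor.config.use_decoupled_loss:
        rollout_batch["prox_logp"] = actor.compute_logp(rollout_batch)
    # Reference log-probs for the KL-regularized reward controlled by kl_ctl.
    if ref is not None:
        rollout_batch["ref_logp"] = ref.compute_logp(rollout_batch)
    actor.compute_advantages(rollout_batch)  # fills advantages, kl_rewards, loss_mask, ...
    return actor.ppo_update(rollout_batch)   # runs ppo_n_minibatches updates, returns stats
```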
@ -20,7 +20,7 @@ class LMEngine:
|
|||
return self.engine.train_batch(
|
||||
input_=data,
|
||||
loss_fn=compute_packed_sft_loss,
|
||||
loss_weight_fn=lambda x: x["prompt_mask"].logical_not().count_nonzero(),
|
||||
loss_weight_fn=lambda x: x["loss_mask"].count_nonzero(),
|
||||
)
|
||||
|
||||
def evaluate_lm(self, data):
|
||||
|
@ -28,7 +28,7 @@ class LMEngine:
|
|||
self.engine.eval_batch(
|
||||
input_=data,
|
||||
loss_fn=compute_packed_sft_loss,
|
||||
loss_weight_fn=lambda x: x["prompt_mask"].logical_not().count_nonzero(),
|
||||
loss_weight_fn=lambda x: x["loss_mask"].count_nonzero(),
|
||||
)
|
||||
|
||||
|
||||
|
@ -49,26 +49,26 @@ def compute_packed_sft_loss(
|
|||
) -> torch.Tensor:
|
||||
packed_input_ids: torch.Tensor = input_["input_ids"]
|
||||
cu_seqlens: torch.Tensor = input_["cu_seqlens"]
|
||||
prompt_mask = input_["prompt_mask"].bool()
|
||||
loss_mask = input_["loss_mask"].bool()
|
||||
|
||||
logprobs = gather_logprobs(logits, torch.roll(packed_input_ids, shifts=-1, dims=-1))
|
||||
prompt_mask = torch.roll(prompt_mask, shifts=-1, dims=-1)
|
||||
logprobs = torch.where(prompt_mask, 0, logprobs)
|
||||
loss_mask = torch.roll(loss_mask, shifts=-1, dims=-1)
|
||||
logprobs = torch.where(loss_mask, logprobs, 0)
|
||||
|
||||
loss = -logprobs.sum() / prompt_mask.logical_not().count_nonzero()
|
||||
loss = -logprobs.sum() / loss_mask.count_nonzero()
|
||||
|
||||
with torch.no_grad():
|
||||
seqlogp = torch.zeros(
|
||||
cu_seqlens.shape[0] - 1, device=logits.device, dtype=torch.float64
|
||||
)
|
||||
for i in range(cu_seqlens.shape[0] - 1):
|
||||
m = prompt_mask[cu_seqlens[i] - i : cu_seqlens[i + 1] - i - 1]
|
||||
m = loss_mask[cu_seqlens[i] - i : cu_seqlens[i + 1] - i - 1]
|
||||
logp = logprobs[cu_seqlens[i] - i : cu_seqlens[i + 1] - i - 1]
|
||||
assert cu_seqlens[i + 1] - i - 1 <= logprobs.shape[0], (
|
||||
cu_seqlens,
|
||||
logprobs.shape,
|
||||
)
|
||||
seqlogp[i] = torch.where(m, 0.0, logp.detach()).sum() / (
|
||||
seqlogp[i] = torch.where(m, logp.detach(), 0.0).sum() / (
|
||||
m.numel() - m.count_nonzero()
|
||||
)
|
||||
|
||||
|
@ -78,8 +78,8 @@ def compute_packed_sft_loss(
|
|||
cu_seqlens.shape[0] - 1, dtype=torch.bool, device=logprobs.device
|
||||
),
|
||||
n_tokens=torch.ones(logits.shape[0], dtype=torch.bool, device=logits.device),
|
||||
n_valid_tokens=prompt_mask.logical_not(),
|
||||
prompt_tokens=prompt_mask,
|
||||
n_valid_tokens=loss_mask,
|
||||
prompt_tokens=loss_mask.logical_not(),
|
||||
)
|
||||
stats_tracker.stat(ppl=(-seqlogp).exp().float(), denominator="n_seqs")
|
||||
stats_tracker.stat(loss=-logprobs.detach(), denominator="n_valid_tokens")
|
||||
|
|
|
@ -1,23 +1,30 @@
|
|||
import asyncio
|
||||
import os
|
||||
import random
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from datetime import datetime
|
||||
from queue import Empty, Full, Queue
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
import torch.distributed as dist
|
||||
from tensordict import TensorDict
|
||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
|
||||
from arealite.api.cli_args import InferenceEngineConfig
|
||||
from arealite.api.engine_api import InferenceEngine
|
||||
from arealite.api.io_struct import (
|
||||
FinetuneSpec,
|
||||
LLMRequest,
|
||||
LLMResponse,
|
||||
RolloutStat,
|
||||
WeightUpdateMeta,
|
||||
)
|
||||
from arealite.utils.http import arequest_with_retry
|
||||
from arealite.utils.padding import concat_padded_tensors
|
||||
from realhf.base import logging, name_resolve, names, pkg_version
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -30,7 +37,7 @@ if pkg_version.is_available("sglang"):
|
|||
else:
|
||||
SGLANG_TOKEN_OUTPUT_IDENTIFIER = "token_ids"
|
||||
|
||||
ROLLOUT_POLL_WAIT_TIME = 0.4
|
||||
ROLLOUT_POLL_WAIT_TIME = 0.1
|
||||
RID_CACHE_SIZE = 128
|
||||
|
||||
|
||||
|
@ -46,26 +53,57 @@ class RemoteSGLangEngine(InferenceEngine):
|
|||
# Maintain the addresses for the recent 128 requests
|
||||
self.rid_queue = []
|
||||
|
||||
self.addresses = config.server_addrs
|
||||
self.server_idx = 0
|
||||
self.addresses = os.getenv("AREAL_LLM_SERVER_ADDRS").split(",")
|
||||
if not self.addresses:
|
||||
raise RuntimeError("No configured SGLang servers.")
|
||||
logger.info("Waiting for server ready...")
|
||||
for addr in self.addresses:
|
||||
self._wait_for_server(addr)
|
||||
logger.info("Servers are all ready!")
|
||||
|
||||
qsize = config.queue_size or config.max_concurrent_rollouts * 10
|
||||
self.server_idx = random.randint(0, len(self.addresses) - 1)
|
||||
|
||||
qsize = config.queue_size or config.max_concurrent_rollouts * 16
|
||||
self.input_queue = Queue(maxsize=qsize)
|
||||
self.output_queue = Queue(maxsize=qsize)
|
||||
self.result_cache = []
|
||||
|
||||
self.exiting = threading.Event()
|
||||
self.paused = threading.Event()
|
||||
self.lock = threading.Lock()
|
||||
|
||||
self.rollout_stat = RolloutStat()
|
||||
|
||||
self._version = 0
|
||||
|
||||
def initialize(self, addr: str | None, ft_spec: Optional[Dict[str, Any]] = None):
|
||||
def _wait_for_server(self, address):
|
||||
base_url = f"http://{address}"
|
||||
tik = time.time()
|
||||
while time.time() - tik < self.config.setup_timeout:
|
||||
if self.check_health(base_url):
|
||||
return
|
||||
time.sleep(1)
|
||||
raise RuntimeError("server launch failed")
|
||||
|
||||
def check_health(self, base_url):
|
||||
# Check server endpoint
|
||||
try:
|
||||
response = requests.get(
|
||||
f"{base_url}/metrics",
|
||||
timeout=30,
|
||||
)
|
||||
return response.status_code == 200
|
||||
except requests.exceptions.RequestException as e:
|
||||
return False
|
||||
|
||||
def initialize(self, addr: str | None, ft_spec: FinetuneSpec = None):
|
||||
self.rollout_tasks: Dict[str, asyncio.Task] = {}
|
||||
self.executor = ProcessPoolExecutor(max_workers=1)
|
||||
self.rollout_thread = threading.Thread(target=self._rollout_thread)
|
||||
self.rollout_thread.start()
|
||||
|
||||
def destroy(self):
|
||||
self.executor.shutdown()
|
||||
self.exiting.set()
|
||||
self.rollout_thread.join()
|
||||
|
||||
|
@ -85,79 +123,45 @@ class RemoteSGLangEngine(InferenceEngine):
|
|||
traceback.print_exc()
|
||||
|
||||
async def _rollout_thread_async(self):
|
||||
data = None
|
||||
|
||||
rollout_tasks: Dict[str, asyncio.Task] = {}
|
||||
pending_data = []
|
||||
rollout_tasks = self.rollout_tasks
|
||||
rid = 0
|
||||
|
||||
try:
|
||||
while not self.exiting.is_set():
|
||||
# Load next data from controller
|
||||
if data is None:
|
||||
while True:
|
||||
try:
|
||||
data, workflow = self.input_queue.get_nowait()
|
||||
logger.info(f"Get data from puller: {data}")
|
||||
logger.debug(f"Get data from puller: {data}")
|
||||
pending_data.append(data)
|
||||
except Empty:
|
||||
logger.debug(f"No data from puller stream.")
|
||||
break
|
||||
|
||||
# Check capacity
|
||||
if dist.is_initialized():
|
||||
world_size = dist.get_world_size()
|
||||
else:
|
||||
world_size = 1
|
||||
|
||||
cannot_rollout_reason = []
|
||||
capacity = max(1, self.config.max_concurrent_rollouts // world_size)
|
||||
can_rollout = len(rollout_tasks) < capacity
|
||||
if not can_rollout:
|
||||
cannot_rollout_reason.append(
|
||||
f"Exceeding capacity: # running tasks {len(rollout_tasks)} >= capacity {capacity}"
|
||||
)
|
||||
|
||||
# Staleness control
|
||||
version = self.get_version()
|
||||
ofp = self.config.max_head_offpolicyness
|
||||
with self.lock:
|
||||
sample_cnt = self.rollout_stat.accepted + self.rollout_stat.running
|
||||
expected_version = sample_cnt // self.config.consumer_batch_size
|
||||
not_staled = expected_version <= ofp + version
|
||||
can_rollout &= not_staled
|
||||
if not not_staled:
|
||||
cannot_rollout_reason.append(
|
||||
f"Staled: expected version ({expected_version}) = "
|
||||
f"global sample cnt ({sample_cnt}) // batch size ({self.config.consumer_batch_size}), "
|
||||
f"current latest version {version}, "
|
||||
f"offpolicyness {self.config.max_head_offpolicyness}."
|
||||
)
|
||||
|
||||
if not can_rollout:
|
||||
logger.debug(
|
||||
f"Cannot submit new rollouts. "
|
||||
+ "\n".join(cannot_rollout_reason)
|
||||
)
|
||||
|
||||
capacity = self.get_capacity()
|
||||
# Create new rollout task
|
||||
if can_rollout and data is not None:
|
||||
while capacity > 0 and pending_data and not self.paused.is_set():
|
||||
task = asyncio.create_task(
|
||||
workflow.arun_episode(self, data), name=str(rid)
|
||||
workflow.arun_episode(self, pending_data.pop(0)), name=str(rid)
|
||||
)
|
||||
rollout_tasks[str(rid)] = task
|
||||
|
||||
with self.lock:
|
||||
rollout_tasks[str(rid)] = task
|
||||
self.rollout_stat.submitted += 1
|
||||
self.rollout_stat.running += 1
|
||||
logger.info(
|
||||
f"Submit rollout rid {rid}. "
|
||||
f"Submit: {self.rollout_stat.submitted}, "
|
||||
f"running: {self.rollout_stat.running}, "
|
||||
f"accepted: {self.rollout_stat.accepted}."
|
||||
)
|
||||
|
||||
if self.config.enable_rollout_tracing:
|
||||
logger.info(
|
||||
f"Submit rollout rid {rid}. "
|
||||
f"Submit: {self.rollout_stat.submitted}, "
|
||||
f"running: {self.rollout_stat.running}, "
|
||||
f"accepted: {self.rollout_stat.accepted}."
|
||||
)
|
||||
capacity -= 1
|
||||
rid += 1
|
||||
data = None
|
||||
|
||||
# Wait for rollout completion
|
||||
tasks = list(rollout_tasks.values())
|
||||
with self.lock:
|
||||
tasks = list(rollout_tasks.values())
|
||||
done = []
|
||||
if tasks:
|
||||
done, _ = await asyncio.wait(
|
||||
|
@ -165,16 +169,19 @@ class RemoteSGLangEngine(InferenceEngine):
|
|||
timeout=ROLLOUT_POLL_WAIT_TIME,
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
if not done:
|
||||
await asyncio.sleep(1)
|
||||
else:
|
||||
await asyncio.sleep(ROLLOUT_POLL_WAIT_TIME)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Collect done results
|
||||
for task in done:
|
||||
traj = await task
|
||||
traj: TensorDict
|
||||
task_rid = task.get_name()
|
||||
rollout_tasks.pop(task_rid)
|
||||
self.rollout_stat.accepted += 1
|
||||
with self.lock:
|
||||
rollout_tasks.pop(task_rid)
|
||||
self.rollout_stat.accepted += 1
|
||||
|
||||
try:
|
||||
self.output_queue.put_nowait(traj)
|
||||
|
@ -185,21 +192,25 @@ class RemoteSGLangEngine(InferenceEngine):
|
|||
|
||||
with self.lock:
|
||||
self.rollout_stat.running -= 1
|
||||
logger.info(
|
||||
f"Finish rollout {task_rid}. "
|
||||
f"Submit: {self.rollout_stat.submitted}, "
|
||||
f"running: {self.rollout_stat.running}, "
|
||||
f"accepted: {self.rollout_stat.accepted}."
|
||||
)
|
||||
if self.config.enable_rollout_tracing:
|
||||
logger.info(
|
||||
f"Finish rollout {task_rid}. "
|
||||
f"Submit: {self.rollout_stat.submitted}, "
|
||||
f"running: {self.rollout_stat.running}, "
|
||||
f"accepted: {self.rollout_stat.accepted}."
|
||||
)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
# Cancel remaining tasks
|
||||
for task in rollout_tasks.values():
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
with self.lock:
|
||||
for task in rollout_tasks.values():
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
def choose_server(self) -> str:
|
||||
if self.config.schedule_policy == "round_robin":
|
||||
|
@ -208,65 +219,6 @@ class RemoteSGLangEngine(InferenceEngine):
|
|||
return server
|
||||
raise NotImplementedError("Only round-robin scheduling is implemented.")
|
||||
|
||||
async def arequest_with_retry(
|
||||
self,
|
||||
endpoint: str,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
method: str = "POST",
|
||||
max_retries: Optional[int] = None,
|
||||
timeout: Optional[float] = None,
|
||||
retry_delay: float = 1.0,
|
||||
target_addr: Optional[str] = None,
|
||||
) -> aiohttp.ClientResponse:
|
||||
timeout = timeout or self.config.request_timeout
|
||||
last_exception = None
|
||||
max_retries = max_retries or self.config.request_retries
|
||||
|
||||
# Try with retries
|
||||
for _ in range(max_retries):
|
||||
if target_addr:
|
||||
addr = target_addr
|
||||
else:
|
||||
addr = self.choose_server()
|
||||
base_url = f"http://{addr}"
|
||||
url = f"{base_url}{endpoint}"
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(
|
||||
total=timeout,
|
||||
sock_connect=30,
|
||||
sock_read=timeout,
|
||||
)
|
||||
) as session:
|
||||
if method.upper() == "GET":
|
||||
response = await session.get(url)
|
||||
elif method.upper() == "POST":
|
||||
response = await session.post(url, json=payload)
|
||||
elif method.upper() == "PUT":
|
||||
response = await session.put(url, json=payload)
|
||||
elif method.upper() == "DELETE":
|
||||
response = await session.delete(url)
|
||||
else:
|
||||
raise ValueError(f"Unsupported HTTP method: {method}")
|
||||
|
||||
response.raise_for_status()
|
||||
return response
|
||||
|
||||
except (
|
||||
aiohttp.ClientError,
|
||||
aiohttp.ClientResponseError,
|
||||
asyncio.TimeoutError,
|
||||
) as e:
|
||||
last_exception = e
|
||||
if attempt < max_retries - 1:
|
||||
await asyncio.sleep(retry_delay)
|
||||
continue
|
||||
raise RuntimeError(
|
||||
f"Failed after {max_retries} retries each. " f"Last error: {last_exception}"
|
||||
)
|
||||
|
||||
async def agenerate(self, req: LLMRequest) -> LLMResponse:
|
||||
"""Async version of generate using aiohttp."""
|
||||
# Prepare request payload
|
||||
|
@ -288,15 +240,11 @@ class RemoteSGLangEngine(InferenceEngine):
|
|||
|
||||
# NOTE: rid should NOT be passed in payload
|
||||
payload = {
|
||||
"text": req.text,
|
||||
"input_ids": req.input_ids.copy(),
|
||||
"sampling_params": sample_params,
|
||||
"return_logprob": True,
|
||||
"stream": False,
|
||||
}
|
||||
if req.text:
|
||||
payload["text"] = req.text
|
||||
else:
|
||||
payload["input_ids"] = req.input_ids
|
||||
|
||||
# Make request
|
||||
start_time = time.perf_counter()
|
||||
|
@ -324,18 +272,16 @@ class RemoteSGLangEngine(InferenceEngine):
|
|||
and len(accumulated_output_tokens) < gconfig.max_new_tokens
|
||||
):
|
||||
# loop until the generation is complete
|
||||
response = await self.arequest_with_retry(
|
||||
result = await arequest_with_retry(
|
||||
addr=self.choose_server(),
|
||||
endpoint="/generate",
|
||||
payload=payload,
|
||||
method="POST",
|
||||
max_retries=3,
|
||||
timeout=self.config.request_timeout,
|
||||
target_addr=server_addr,
|
||||
)
|
||||
result = await response.json()
|
||||
|
||||
# Parse response
|
||||
completions += result["text"]
|
||||
meta_info = result["meta_info"]
|
||||
output_tokens = [x[1] for x in meta_info["output_token_logprobs"]]
|
||||
output_logprobs = [x[0] for x in meta_info["output_token_logprobs"]]
|
||||
|
@ -350,12 +296,15 @@ class RemoteSGLangEngine(InferenceEngine):
|
|||
finish_reason = meta_info["finish_reason"]
|
||||
stop_reason = finish_reason["type"]
|
||||
|
||||
payload["text"] += result["text"]
|
||||
payload["input_ids"] += result[SGLANG_TOKEN_OUTPUT_IDENTIFIER]
|
||||
sample_params["max_new_tokens"] = min(
|
||||
sample_params["max_new_tokens"],
|
||||
gconfig.max_new_tokens - len(output_tokens),
|
||||
)
|
||||
|
||||
latency = time.perf_counter() - start_time
|
||||
|
||||
return LLMResponse(
|
||||
completions=completions,
|
||||
input_tokens=req.input_ids,
|
||||
output_tokens=accumulated_output_tokens,
|
||||
output_logprobs=accumulated_output_logprobs,
|
||||
|
@ -365,56 +314,47 @@ class RemoteSGLangEngine(InferenceEngine):
|
|||
ttft=latency, # Simplified for non-streaming
|
||||
)
|
||||
|
||||
def update_weights(self, meta):
|
||||
executor = ThreadPoolExecutor(max_workers=1)
|
||||
return executor.submit(self._update_weights, meta)
|
||||
|
||||
def _update_weights(self, meta: WeightUpdateMeta):
|
||||
def update_weights(self, meta: WeightUpdateMeta):
|
||||
if meta.type == "disk":
|
||||
# Update weights from disk
|
||||
# Wait for model checkpoints of meta.version
|
||||
update_name = names.update_weights_from_disk(
|
||||
self.config.experiment_name, self.config.trial_name, meta.model_version
|
||||
# Use ProcessPool to bypass python GIL for running async coroutines
|
||||
fut = self.executor.submit(
|
||||
update_weights_from_disk,
|
||||
self.config.experiment_name,
|
||||
self.config.trial_name,
|
||||
meta.model_version,
|
||||
self.addresses,
|
||||
meta.path,
|
||||
self.config.request_retries,
|
||||
self.config.request_timeout,
|
||||
)
|
||||
save_timestamp = int(name_resolve.wait(update_name, timeout=120))
|
||||
load_timestamp = time.time_ns()
|
||||
logger.info(
|
||||
f"Begin update weights from {meta.path}, responded in {(load_timestamp - save_timestamp)/1e6:.2f} ms"
|
||||
)
|
||||
try:
|
||||
jobs = [
|
||||
self.aupdate_weights_from_disk(addr, meta.path)
|
||||
for addr in self.addresses
|
||||
]
|
||||
loop = asyncio.new_event_loop()
|
||||
# asyncio event loop should be manually set when running asyncio stuff in another thread
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.run_until_complete(asyncio.gather(*jobs))
|
||||
finally:
|
||||
loop.close()
|
||||
logger.info(
|
||||
f"Loading weights done in {(time.time_ns() - load_timestamp)/1e6:.2f} ms"
|
||||
)
|
||||
self.set_version(meta.model_version)
|
||||
|
||||
def callback(fut):
|
||||
self.set_version(meta.model_version)
|
||||
|
||||
fut.add_done_callback(callback)
|
||||
return fut
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported weight update type: {meta.type}")
|
||||
|
||||
async def aupdate_weights_from_disk(self, addr, path: str):
|
||||
response = await self.arequest_with_retry(
|
||||
endpoint="/update_weights_from_disk",
|
||||
payload=dict(model_path=str(path), allow_interrupt=True),
|
||||
method="POST",
|
||||
max_retries=3,
|
||||
timeout=self.config.request_timeout,
|
||||
target_addr=addr,
|
||||
def get_capacity(self):
|
||||
if dist.is_initialized():
|
||||
world_size = dist.get_world_size()
|
||||
else:
|
||||
world_size = 1
|
||||
|
||||
max_concurrent_rollouts = max(
|
||||
1, self.config.max_concurrent_rollouts // world_size
|
||||
)
|
||||
res = await response.json()
|
||||
assert res["success"]
|
||||
if "num_paused_requests" in res:
|
||||
logger.info(
|
||||
f"{res['num_paused_requests']} requests are interrupted "
|
||||
f"during updating weights for server {addr}"
|
||||
)
|
||||
capacity = max_concurrent_rollouts - len(self.rollout_tasks)
|
||||
# Staleness control
|
||||
version = self.get_version()
|
||||
ofp = self.config.max_head_offpolicyness
|
||||
with self.lock:
|
||||
sample_cnt = self.rollout_stat.accepted + self.rollout_stat.running
|
||||
consumer_bs = max(1, self.config.consumer_batch_size // world_size)
|
||||
capacity = min(capacity, (ofp + version + 1) * consumer_bs - sample_cnt)
|
||||
return capacity
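To make the staleness control above concrete, here is a small illustrative sketch (not part of the commit) that plugs hypothetical numbers into the same capacity formula: in-flight rollouts are capped both by the per-rank concurrency budget and by how far ahead of the consumer the generator may run.

# Hypothetical numbers; only the two formulas mirror get_capacity above.
def capacity_example() -> int:
    max_concurrent_rollouts = 64  # per-rank concurrency budget
    running_tasks = 10            # len(self.rollout_tasks)
    version = 3                   # current weight version
    ofp = 4                       # config.max_head_offpolicyness
    consumer_bs = 32              # per-rank consumer batch size
    sample_cnt = 200              # rollout_stat.accepted + rollout_stat.running

    capacity = max_concurrent_rollouts - running_tasks
    capacity = min(capacity, (ofp + version + 1) * consumer_bs - sample_cnt)
    return capacity  # min(54, 8 * 32 - 200) = 54


print(capacity_example())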
|
||||
|
||||
def submit(self, data: Dict[str, Any], workflow: "RolloutWorkflow") -> None:
|
||||
try:
|
||||
|
@ -422,9 +362,15 @@ class RemoteSGLangEngine(InferenceEngine):
|
|||
except Full:
|
||||
raise RuntimeError("Input queue full. Please increase queue_size.")
|
||||
|
||||
def wait(self, count: int, timeout: float, should_accept: Callable) -> TensorDict:
|
||||
def wait(
|
||||
self,
|
||||
count: int,
|
||||
timeout: float | None = None,
|
||||
should_accept: Callable | None = None,
|
||||
) -> TensorDict:
|
||||
tik = time.perf_counter()
|
||||
accepted = len(self.result_cache)
|
||||
timeout = timeout or float(7 * 24 * 3600)
|
||||
while (
|
||||
accepted < count
|
||||
and not self.exiting.is_set()
|
||||
|
@ -432,14 +378,14 @@ class RemoteSGLangEngine(InferenceEngine):
|
|||
):
|
||||
try:
|
||||
result = self.output_queue.get(timeout=ROLLOUT_POLL_WAIT_TIME)
|
||||
if should_accept(result):
|
||||
if should_accept is None or should_accept(result):
|
||||
self.result_cache.append(result)
|
||||
accepted += 1
|
||||
else:
|
||||
with self.lock:
|
||||
self.rollout_stat.accepted -= 1
|
||||
except Empty:
|
||||
time.sleep(ROLLOUT_POLL_WAIT_TIME)
|
||||
pass
|
||||
if self.exiting.is_set():
|
||||
raise RuntimeError("Rollout engine is exiting, cannot wait for results.")
|
||||
if accepted < count:
|
||||
|
@ -450,16 +396,94 @@ class RemoteSGLangEngine(InferenceEngine):
|
|||
self.result_cache[:count],
|
||||
self.result_cache[count:],
|
||||
)
|
||||
return TensorDict.cat(results, dim=0)
|
||||
return concat_padded_tensors(results)
|
||||
|
||||
def rollout(
|
||||
def rollout_batch(
|
||||
self, data: List[Dict[str, Any]], workflow: "RolloutWorkflow"
|
||||
) -> TensorDict:
|
||||
"""Submit a batch of requests to the inference engine and wait for the results."""
|
||||
for item in data:
|
||||
self.submit(item, workflow)
|
||||
return self.wait(
|
||||
count=len(data),
|
||||
timeout=self.config.request_timeout,
|
||||
should_accept=lambda x: True,
|
||||
return self.wait(count=len(data))
|
||||
|
||||
def prepare_batch(
|
||||
self,
|
||||
data_generator: Iterator,
|
||||
dataloader: StatefulDataLoader,
|
||||
workflow: "RolloutWorkflow",
|
||||
):
|
||||
assert dataloader.batch_size is not None
|
||||
while True:
|
||||
if self.get_capacity() + dataloader.batch_size > 0:
|
||||
try:
|
||||
data = next(data_generator)
|
||||
except StopIteration:
|
||||
data_generator = iter(dataloader)
|
||||
data = next(data_generator)
|
||||
for item in data:
|
||||
self.submit(item, workflow=workflow)
|
||||
try:
|
||||
return self.wait(dataloader.batch_size, timeout=1)
|
||||
except TimeoutError:
|
||||
pass
|
||||
|
||||
def pause(self):
|
||||
self.paused.set()
|
||||
|
||||
def resume(self):
|
||||
self.paused.clear()
|
||||
|
||||
|
||||
async def aupdate_weights_from_disk(
|
||||
addr, path: str, request_retries: int, request_timeout: float
|
||||
):
|
||||
res = await arequest_with_retry(
|
||||
addr=addr,
|
||||
endpoint="/update_weights_from_disk",
|
||||
payload=dict(model_path=str(path), allow_interrupt=True),
|
||||
method="POST",
|
||||
max_retries=request_retries,
|
||||
timeout=request_timeout,
|
||||
)
|
||||
assert res["success"]
|
||||
if "num_paused_requests" in res:
|
||||
logger.info(
|
||||
f"{res['num_paused_requests']} requests are interrupted "
|
||||
f"during updating weights for server {addr}"
|
||||
)
|
||||
|
||||
|
||||
def update_weights_from_disk(
|
||||
experiment_name,
|
||||
trial_name,
|
||||
model_version,
|
||||
addresses,
|
||||
path,
|
||||
request_retries,
|
||||
request_timeout,
|
||||
):
|
||||
async def _fn():
|
||||
# Wait for model checkpoints of meta.version
|
||||
update_name = names.update_weights_from_disk(
|
||||
experiment_name, trial_name, model_version
|
||||
)
|
||||
save_timestamp = float(name_resolve.wait(update_name, timeout=120))
|
||||
load_timestamp = datetime.now().timestamp()
|
||||
logger.info(
|
||||
f"Begin update weights from {path}, responded in {(load_timestamp - save_timestamp):.2f}s"
|
||||
)
|
||||
jobs = [
|
||||
aupdate_weights_from_disk(
|
||||
addr,
|
||||
path=path,
|
||||
request_retries=request_retries,
|
||||
request_timeout=request_timeout,
|
||||
)
|
||||
for addr in addresses
|
||||
]
|
||||
await asyncio.gather(*jobs)
|
||||
logger.info(
|
||||
f"Loading weights done in {(datetime.now().timestamp() - load_timestamp):.2f}s"
|
||||
)
|
||||
|
||||
return asyncio.run(_fn())
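As a hedged usage sketch (all argument values below are made up), the helper above can also be driven directly; note that it blocks on name_resolve until the trainer side has published the checkpoint timestamp for this model version.

update_weights_from_disk(
    experiment_name="gsm8k-grpo",
    trial_name="trial0",
    model_version=5,
    addresses=["localhost:30000", "localhost:30001"],
    path="/tmp/arealite/experiments/checkpoints/.../epoch1epochstep5globalstep5",
    request_retries=3,
    request_timeout=3600,
)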
|
||||
|
|
|
@ -0,0 +1,307 @@
|
|||
import getpass
|
||||
import os
|
||||
import re
|
||||
import signal as signal_module
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import psutil
|
||||
|
||||
from arealite.api.cli_args import SGLangConfig, parse_cli_args, to_structured_cfg
|
||||
from arealite.api.io_struct import AllocationMode, AllocationType
|
||||
from arealite.utils.network import find_free_ports, gethostip
|
||||
from realhf.base import gpu_utils, logging, name_resolve, names
|
||||
from realhf.scheduler.client import JobException, JobInfo, JobState
|
||||
|
||||
logger = logging.getLogger("Local Scheduler")
|
||||
|
||||
JOB_STATE_TO_PROCESS_STATUS = {
|
||||
JobState.NOT_FOUND: [],
|
||||
JobState.PENDING: [psutil.STATUS_PARKED],
|
||||
JobState.RUNNING: [
|
||||
psutil.STATUS_RUNNING,
|
||||
psutil.STATUS_SLEEPING,
|
||||
psutil.STATUS_DISK_SLEEP,
|
||||
psutil.STATUS_TRACING_STOP,
|
||||
psutil.STATUS_WAKING,
|
||||
psutil.STATUS_WAITING,
|
||||
psutil.STATUS_LOCKED,
|
||||
psutil.STATUS_IDLE,
|
||||
],
|
||||
JobState.COMPLETED: [
|
||||
psutil.STATUS_DEAD,
|
||||
psutil.STATUS_STOPPED,
|
||||
psutil.STATUS_ZOMBIE,
|
||||
],
|
||||
JobState.FAILED: [],
|
||||
JobState.CANCELLED: [],
|
||||
}
|
||||
|
||||
PROCESS_STATUS_TO_JOB_STATE = {}
|
||||
for job_state, process_statuses in JOB_STATE_TO_PROCESS_STATUS.items():
|
||||
for process_status in process_statuses:
|
||||
PROCESS_STATUS_TO_JOB_STATE[process_status] = job_state
|
||||
|
||||
|
||||
def terminate_process_and_children(pid: int, signal: Optional[Union[str, int]] = None):
|
||||
if signal is None:
|
||||
signal = signal_module.SIGKILL
|
||||
if isinstance(signal, str):
|
||||
signal = getattr(signal_module, signal)
|
||||
try:
|
||||
parent = psutil.Process(pid)
|
||||
children = parent.children(recursive=True)
|
||||
for child in children:
|
||||
terminate_process_and_children(child.pid)
|
||||
parent.send_signal(signal)
|
||||
except psutil.NoSuchProcess:
|
||||
pass
|
||||
|
||||
|
||||
class LocalLauncher:
|
||||
def __init__(self, experiment_name: str, trial_name: str, fileroot: str):
|
||||
self.experiment_name = experiment_name
|
||||
self.trial_name = trial_name
|
||||
self.fileroot = fileroot
|
||||
|
||||
self._jobs: Dict[str, subprocess.Popen] = {}
|
||||
self._job_counter: Dict[str, int] = defaultdict(int)
|
||||
self._job_states = {}
|
||||
|
||||
self._gpu_counter = 0
|
||||
self._cuda_devices: List[str] = os.environ.get(
|
||||
"CUDA_VISIBLE_DEVICES", ",".join(map(str, range(gpu_utils.gpu_count())))
|
||||
).split(",")
|
||||
if len(self._cuda_devices) < 1:
|
||||
raise RuntimeError(
|
||||
f"Local mode can only run when there is at least one GPU. "
|
||||
f"CUDA_VISIBLE_DEVICES is currently set to {os.environ['CUDA_VISIBLE_DEVICES']}."
|
||||
)
|
||||
|
||||
@property
|
||||
def run_name(self):
|
||||
return f"{self.experiment_name}_{self.trial_name}"
|
||||
|
||||
def log_path_of(self, job_name: str) -> str:
|
||||
log_path = f"{self.fileroot}/logs/{getpass.getuser()}/{self.experiment_name}/{self.trial_name}"
|
||||
os.makedirs(log_path, exist_ok=True)
|
||||
return os.path.join(log_path, f"{job_name}.log")
|
||||
|
||||
def __del__(self):
|
||||
self.wait()
|
||||
|
||||
def submit_array(
|
||||
self,
|
||||
job_name: str,
|
||||
cmd: str | List[str],
|
||||
count: int = 1,
|
||||
gpu: int = 0,
|
||||
env_vars: Optional[Dict] = None,
|
||||
):
|
||||
if env_vars is None:
|
||||
env_vars = {}
|
||||
if not isinstance(cmd, list):
|
||||
cmd = [cmd] * count
|
||||
offset = self._job_counter[job_name]
|
||||
for i in range(count):
|
||||
if gpu > 0:
|
||||
# Allocate GPUs in a round-robin manner
|
||||
visible_devices = []
|
||||
for _ in range(gpu):
|
||||
available_device_id = self._gpu_counter % len(self._cuda_devices)
|
||||
self._gpu_counter += 1
|
||||
visible_devices.append(available_device_id)
|
||||
env_vars["CUDA_VISIBLE_DEVICES"] = ",".join(
|
||||
str(self._cuda_devices[j]) for j in visible_devices
|
||||
)
|
||||
c = (
|
||||
" ".join(str(k) + "=" + str(v) for k, v in env_vars.items())
|
||||
+ " stdbuf -oL "
|
||||
+ cmd[i]
|
||||
)
|
||||
c = f"{c} | tee -a {self.log_path_of(job_name)}"
|
||||
logger.info("Starting local process with command: %s", c)
|
||||
process = subprocess.Popen(c, shell=isinstance(c, str))
|
||||
self._jobs[f"{job_name}/{offset + i}"] = process
|
||||
self._job_counter[job_name] += 1
|
||||
|
||||
def submit(
|
||||
self,
|
||||
job_name: str,
|
||||
cmd: str | List[str],
|
||||
gpu: int = 0,
|
||||
env_vars: Optional[Dict] = None,
|
||||
):
|
||||
self.submit_array(job_name=job_name, cmd=cmd, gpu=gpu, env_vars=env_vars)
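A minimal illustrative use of the launcher (the commands below are placeholders, not jobs from this repository); each array element receives its own slice of CUDA_VISIBLE_DEVICES in round-robin order.

launcher = LocalLauncher("demo-exp", "trial0", fileroot="/tmp/arealite/experiments")
launcher.submit_array(
    job_name="toy_server",
    cmd=["python -m http.server 8000", "python -m http.server 8001"],
    count=2,
    gpu=1,
)
# ... later, tear everything down with SIGKILL (the default) ...
launcher.stop_all()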
|
||||
|
||||
def stop(self, job_name, signal=None):
|
||||
assert any(k.startswith(job_name) for k in self._jobs)
|
||||
keys = [k for k, p in self._jobs.items() if k.startswith(job_name)]
|
||||
procs = [p for k, p in self._jobs.items() if k.startswith(job_name)]
|
||||
logger.info(
|
||||
f"Stopping local process with signal {signal if signal else 'SIGKILL'}, "
|
||||
f"pid: {[p.pid for p in procs]}"
|
||||
)
|
||||
for p in procs:
|
||||
terminate_process_and_children(p.pid, signal=signal)
|
||||
for p in procs:
|
||||
p.wait()
|
||||
for k, p in zip(keys, procs):
|
||||
self._jobs.pop(k)
|
||||
del p
|
||||
|
||||
def stop_all(self, signal=None):
|
||||
# forward the received signal to every job's stop() (defaults to SIGKILL there)
|
||||
for name in self._job_counter:
|
||||
self.stop(name, signal=signal)
|
||||
|
||||
def find(self, job_name):
|
||||
if job_name in self._jobs:
|
||||
return JobInfo(name=job_name, state=JobState.RUNNING, host="localhost")
|
||||
else:
|
||||
return JobInfo(name=job_name, state=JobState.NOT_FOUND)
|
||||
|
||||
def find_all(self, job_name_regex=".*"):
|
||||
rs = []
|
||||
for name in self._jobs:
|
||||
if re.fullmatch(job_name_regex, name):
|
||||
rs.append(self.find(name))
|
||||
return rs
|
||||
|
||||
def wait(
|
||||
self,
|
||||
timeout=None,
|
||||
check_status: Tuple[JobState, ...] = (
|
||||
JobState.CANCELLED,
|
||||
JobState.FAILED,
|
||||
JobState.NOT_FOUND,
|
||||
),
|
||||
remove_status: Tuple[JobState, ...] = (JobState.COMPLETED,),
|
||||
update=False,
|
||||
):
|
||||
deadline = None if timeout is None else time.time() + timeout
|
||||
logger.info(
|
||||
"Waiting for %d local running processes, pids: %s",
|
||||
len(self._jobs),
|
||||
" ".join(str(job.pid) for job in self._jobs.values()),
|
||||
)
|
||||
left = set(self._jobs.keys())
|
||||
num_jobs_left = len(left)
|
||||
|
||||
while len(left) > 0:
|
||||
to_remove = []
|
||||
if len(left) < num_jobs_left:
|
||||
num_jobs_left = len(left)
|
||||
logger.info(f"Waiting for {num_jobs_left} jobs.")
|
||||
if deadline is not None and time.time() > deadline:
|
||||
raise TimeoutError(
|
||||
f"Timeout waiting for {self.run_name}: {', '.join(sorted(left))}"
|
||||
)
|
||||
# update job states
|
||||
for job_name in list(left):
|
||||
job = self._jobs[job_name]
|
||||
pid = job.pid
|
||||
process = psutil.Process(pid)
|
||||
self._job_states[job_name] = PROCESS_STATUS_TO_JOB_STATE.get(
|
||||
process.status(), JobState.NOT_FOUND
|
||||
)
|
||||
|
||||
for job_name in list(left):
|
||||
state = self._job_states[job_name]
|
||||
if state in check_status:
|
||||
raise JobException(
|
||||
run_name=self.run_name,
|
||||
worker_type=job_name.split("/")[0],
|
||||
host="local",
|
||||
reason=state,
|
||||
)
|
||||
if state in remove_status:
|
||||
logger.info(f"Job {job_name} is {state}.(Removed)")
|
||||
left.remove(job_name)
|
||||
to_remove.append(job_name)
|
||||
|
||||
if update:
|
||||
for k in to_remove:
|
||||
self._jobs.pop(k)
|
||||
worker_type = k.split("/")[0]
|
||||
assert worker_type in self._job_counter
|
||||
self._job_counter[worker_type] -= 1
|
||||
if self._job_counter[worker_type] <= 0:
|
||||
self._job_counter.pop(worker_type)
|
||||
|
||||
time.sleep(2)
|
||||
|
||||
|
||||
def main_local():
|
||||
cfg, _ = parse_cli_args(sys.argv[2:])
|
||||
name_resolve.reconfigure(cfg.cluster.name_resolve)
|
||||
name_resolve.clear_subtree(
|
||||
names.trial_root(experiment_name=cfg.experiment_name, trial_name=cfg.trial_name)
|
||||
)
|
||||
alloc_mode = AllocationMode.from_str(cfg.allocation_mode)
|
||||
|
||||
launcher = LocalLauncher(cfg.experiment_name, cfg.trial_name, cfg.cluster.fileroot)
|
||||
|
||||
server_cmd = []
|
||||
server_addrs = []
|
||||
if alloc_mode.type_ == AllocationType.DECOUPLED_SGLANG:
|
||||
base_seed = cfg.sglang.random_seed
|
||||
cfg.sglang = to_structured_cfg(cfg.sglang, SGLangConfig)
|
||||
ports = find_free_ports(alloc_mode.gen_dp_size * 2, port_range=(10000, 50000))
|
||||
host_ip = gethostip()
|
||||
host = "localhost" if not cfg.sglang.enable_metrics else host_ip
|
||||
for i in range(alloc_mode.gen_dp_size):
|
||||
cfg.sglang.random_seed = base_seed + i
|
||||
cmd = SGLangConfig.build_cmd(
|
||||
cfg.sglang,
|
||||
host=host,
|
||||
tp_size=alloc_mode.gen_tp_size,
|
||||
base_gpu_id=0,
|
||||
port=ports[i * 2],
|
||||
dist_init_addr=f"localhost:{ports[i*2+1]}",
|
||||
)
|
||||
server_cmd.append(cmd)
|
||||
server_addrs.append(f"{host}:{ports[i * 2]}")
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
# Launch inference servers.
|
||||
launcher.submit_array(
|
||||
job_name="llm_server",
|
||||
cmd=server_cmd,
|
||||
count=alloc_mode.gen_dp_size,
|
||||
gpu=alloc_mode.gen_pp_size * alloc_mode.gen_tp_size,
|
||||
)
|
||||
logger.info(
|
||||
f"LLM inference server launched at: AREAL_LLM_SERVER_ADDRS={','.join(server_addrs)}"
|
||||
)
|
||||
|
||||
# Launch trainer entrypoint
|
||||
if not cfg.server_only:
|
||||
launcher.submit(
|
||||
job_name="trainer",
|
||||
cmd=f"torchrun --nnodes 1 --nproc-per-node {alloc_mode.train_world_size} --standalone {' '.join(sys.argv[1:])}",
|
||||
gpu=alloc_mode.train_world_size,
|
||||
env_vars=dict(AREAL_LLM_SERVER_ADDRS=",".join(server_addrs)),
|
||||
)
|
||||
|
||||
try:
|
||||
launcher.wait(
|
||||
check_status=(
|
||||
JobState.CANCELLED,
|
||||
JobState.FAILED,
|
||||
JobState.NOT_FOUND,
|
||||
JobState.COMPLETED,
|
||||
),
|
||||
remove_status=(),
|
||||
)
|
||||
except (KeyboardInterrupt, JobException, TimeoutError) as e:
|
||||
launcher.stop_all("SIGTERM")
|
||||
raise e
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main_local()
|
|
@ -5,7 +5,6 @@ import time
|
|||
import uuid
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
import torch
|
||||
from tensordict import TensorDict
|
||||
|
||||
|
@ -15,52 +14,41 @@ from arealite.api.cli_args import (
|
|||
SGLangConfig,
|
||||
)
|
||||
from arealite.api.io_struct import LLMRequest, LLMResponse, WeightUpdateMeta
|
||||
from arealite.utils import network
|
||||
from realhf.api.core.data_api import load_hf_tokenizer
|
||||
from realhf.base import network
|
||||
|
||||
EXPR_NAME = "test_sglang_engine"
|
||||
TRIAL_NAME = "trial_0"
|
||||
MODEL_PATH = "/storage/testing/models/Qwen__Qwen3-1.7B/"
|
||||
if not os.path.exists(MODEL_PATH):
|
||||
MODEL_PATH = "Qwen/Qwen2-0.5B"
|
||||
PORT = 13887
|
||||
DIST_PORT = 15887
|
||||
PORT, DIST_PORT = network.find_free_ports(2)
|
||||
HOST = network.gethostip()
|
||||
# set a large timeout since we may need to download the model from hub
|
||||
RUN_SERVER_TIMEOUT = 180
|
||||
|
||||
|
||||
def check_server_health(base_url):
|
||||
# Check server endpoint
|
||||
try:
|
||||
response = requests.get(
|
||||
f"{base_url}/metrics",
|
||||
timeout=30,
|
||||
)
|
||||
return response.status_code == 200
|
||||
except requests.exceptions.RequestException:
|
||||
return False
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def sglang_server():
|
||||
from realhf.base import seeding
|
||||
|
||||
seeding.set_random_seed(1, EXPR_NAME)
|
||||
cmd = SGLangConfig.build_cmd(
|
||||
sglang_config=SGLangConfig(mem_fraction_static=0.3),
|
||||
model_path=MODEL_PATH,
|
||||
sglang_config=SGLangConfig(
|
||||
skip_tokenizer_init=True,
|
||||
model_path=MODEL_PATH,
|
||||
mem_fraction_static=0.3,
|
||||
),
|
||||
host=HOST,
|
||||
port=PORT,
|
||||
tp_size=1,
|
||||
base_gpu_id=0,
|
||||
dist_init_addr=f"{HOST}:{DIST_PORT}",
|
||||
served_model_name=MODEL_PATH,
|
||||
skip_tokenizer_init=False,
|
||||
)
|
||||
# Launch process
|
||||
full_command = f"{cmd} --port {PORT}"
|
||||
full_command = full_command.replace("\\\n", " ").replace("\\", " ")
|
||||
cmd = cmd.replace("\\\n", " ").replace("\\", " ")
|
||||
process = subprocess.Popen(
|
||||
full_command.split(),
|
||||
cmd.split(),
|
||||
text=True,
|
||||
stdout=sys.stdout,
|
||||
stderr=sys.stdout,
|
||||
|
@ -82,11 +70,12 @@ async def test_remote_sglang_generate(sglang_server):
|
|||
from arealite.engine.sglang_remote import RemoteSGLangEngine
|
||||
|
||||
config = InferenceEngineConfig(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME)
|
||||
config.server_addrs = [f"{HOST}:{PORT}"]
|
||||
tokenizer = load_hf_tokenizer(MODEL_PATH)
|
||||
os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}"
|
||||
engine = RemoteSGLangEngine(config)
|
||||
req = LLMRequest(
|
||||
rid=str(uuid.uuid4()),
|
||||
text="hello! how are you today",
|
||||
input_ids=tokenizer.encode("hello! how are you today"),
|
||||
gconfig=GenerationHyperparameters(max_new_tokens=16),
|
||||
)
|
||||
resp = await engine.agenerate(req)
|
||||
|
@ -97,7 +86,6 @@ async def test_remote_sglang_generate(sglang_server):
|
|||
== len(resp.output_tokens)
|
||||
== len(resp.output_versions)
|
||||
)
|
||||
assert isinstance(resp.completions, str)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_samples", [1, 2, 4])
|
||||
|
@ -111,7 +99,7 @@ def test_remote_sglang_rollout(sglang_server, n_samples):
|
|||
max_concurrent_rollouts=2,
|
||||
consumer_batch_size=2,
|
||||
)
|
||||
config.server_addrs = [f"{HOST}:{PORT}"]
|
||||
os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}"
|
||||
engine = RemoteSGLangEngine(config)
|
||||
engine.initialize(None, None)
|
||||
|
||||
|
@ -124,12 +112,13 @@ def test_remote_sglang_rollout(sglang_server, n_samples):
|
|||
reward_fn=lambda **kwargs: 1.0, # Dummy reward function
|
||||
gconfig=gconfig,
|
||||
tokenizer=tokenizer,
|
||||
enable_thinking=False,
|
||||
)
|
||||
|
||||
data = {
|
||||
"messages": [{"role": "user", "content": "Hello, how are you?"}],
|
||||
}
|
||||
result = engine.rollout([data] * 2, workflow=workflow)
|
||||
result = engine.rollout_batch([data] * 2, workflow=workflow)
|
||||
assert isinstance(result, TensorDict)
|
||||
bs = result.batch_size
|
||||
assert bs == torch.Size([2 * n_samples])
|
||||
|
@ -149,7 +138,7 @@ def test_remote_sglang_staleness_control(sglang_server, bs, ofp, n_samples):
|
|||
consumer_batch_size=bs,
|
||||
max_head_offpolicyness=ofp,
|
||||
)
|
||||
config.server_addrs = [f"{HOST}:{PORT}"]
|
||||
os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}"
|
||||
engine = RemoteSGLangEngine(config)
|
||||
engine.initialize(None, None)
|
||||
|
||||
|
@ -162,6 +151,7 @@ def test_remote_sglang_staleness_control(sglang_server, bs, ofp, n_samples):
|
|||
reward_fn=lambda **kwargs: 1.0, # Dummy reward function
|
||||
gconfig=gconfig,
|
||||
tokenizer=tokenizer,
|
||||
enable_thinking=False,
|
||||
)
|
||||
data = {
|
||||
"messages": [{"role": "user", "content": "Hello, how are you?"}],
|
||||
|
@ -170,7 +160,7 @@ def test_remote_sglang_staleness_control(sglang_server, bs, ofp, n_samples):
|
|||
engine.submit(data, workflow=workflow)
|
||||
|
||||
# wait for some time
|
||||
time.sleep(15)
|
||||
time.sleep(5)
|
||||
assert engine.output_queue.qsize() == min(bs * 2, bs * (ofp + 1))
|
||||
|
||||
# Update model version
|
||||
|
@ -181,7 +171,7 @@ def test_remote_sglang_staleness_control(sglang_server, bs, ofp, n_samples):
|
|||
for _ in range(bs * 2):
|
||||
engine.submit(data, workflow=workflow)
|
||||
# wait for some time
|
||||
time.sleep(15)
|
||||
time.sleep(5)
|
||||
assert engine.output_queue.qsize() == min(bs * 4, bs * (ofp + 2))
|
||||
|
||||
# exit
|
||||
|
@ -222,8 +212,9 @@ def test_disk_update_weights_from_fsdp_engine(tmp_path_factory, sglang_server):
|
|||
from arealite.engine.sglang_remote import RemoteSGLangEngine
|
||||
|
||||
config = InferenceEngineConfig(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME)
|
||||
config.server_addrs = [f"{HOST}:{PORT}"]
|
||||
os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}"
|
||||
inf_engine = RemoteSGLangEngine(config)
|
||||
inf_engine.initialize(None, None)
|
||||
# test update weights
|
||||
path = tmp_path_factory.mktemp("upload_weights_from_disk")
|
||||
update_weight_meta = WeightUpdateMeta(
|
||||
|
@ -233,3 +224,4 @@ def test_disk_update_weights_from_fsdp_engine(tmp_path_factory, sglang_server):
|
|||
engine.upload_weights(update_weight_meta)
|
||||
future.result()
|
||||
assert inf_engine.get_version() == 100
|
||||
inf_engine.destroy()
|
||||
|
|
|
@ -8,7 +8,7 @@ from arealite.utils.data import (
|
|||
pad_and_stack_tensors_along_first_dim,
|
||||
pad_sequences_to_tensors,
|
||||
reorder_list,
|
||||
split_packed_tensor_dict_into_mb_list,
|
||||
split_padded_tensor_dict_into_mb_list,
|
||||
unpack_sequence,
|
||||
)
|
||||
|
||||
|
@ -28,7 +28,7 @@ def mock_padded_data():
|
|||
ans_len = int(ans_len)
|
||||
seq = dict(
|
||||
input_ids=torch.randint(0, VOCAB_SIZE, size=(prompt_len + ans_len,)),
|
||||
prompt_mask=torch.tensor([1] * prompt_len + [0] * ans_len),
|
||||
loss_mask=torch.tensor([0] * prompt_len + [1] * ans_len),
|
||||
logprobs=torch.randn(prompt_len + ans_len),
|
||||
position_ids=torch.arange(prompt_len + ans_len),
|
||||
)
|
||||
|
@ -45,7 +45,8 @@ def test_micro_batch_split(mock_padded_data, n_mbs, max_tokens_per_mb):
|
|||
packed_data = pack_tensor_dict(mock_padded_data)
|
||||
original_lens = packed_data["cu_seqlens"][1:] - packed_data["cu_seqlens"][:-1]
|
||||
assert torch.allclose(original_lens, mock_padded_data["attention_mask"].sum(1))
|
||||
split_result = split_packed_tensor_dict_into_mb_list(packed_data, mb_spec)
|
||||
split_result = split_padded_tensor_dict_into_mb_list(mock_padded_data, mb_spec)
|
||||
split_result.mbs = [pack_tensor_dict(mb) for mb in split_result.mbs]
|
||||
reordered_lens = [original_lens[i] for i in split_result.forward_indices]
|
||||
|
||||
# assert microbatch split result does not violate requirements
|
||||
|
|
|
@ -110,11 +110,11 @@ def pad_input(hidden_states, indices, batch, seqlen):
|
|||
|
||||
|
||||
def concat_padded_tensors(
|
||||
tensor_dicts: List[Dict[str, torch.Tensor]], pad_value: float = 0.0
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
tensor_dicts: List[TensorDict], pad_value: float = 0.0
|
||||
) -> TensorDict:
|
||||
"""Concatenate and pad tensors from multiple padded tensor dictionaries."""
|
||||
if not tensor_dicts:
|
||||
return {}
|
||||
return TensorDict()
|
||||
|
||||
# Find max sequence length across all dictionaries
|
||||
lens = []
|
||||
|
@ -156,7 +156,7 @@ def concat_padded_tensors(
|
|||
result[key] = torch.cat(tensors_to_concat, dim=0)
|
||||
if "attention_mask" not in result:
|
||||
result["attention_mask"] = attn_mask
|
||||
return result
|
||||
return TensorDict(result, batch_size=[len(lens)])
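A toy check (not from the test suite) of the padding behavior: two single-sequence TensorDicts of different lengths come back as one batch padded to the longer length.

import torch
from tensordict import TensorDict

a = TensorDict(
    {
        "input_ids": torch.ones(1, 3, dtype=torch.long),
        "attention_mask": torch.ones(1, 3, dtype=torch.bool),
    },
    batch_size=[1],
)
b = TensorDict(
    {
        "input_ids": torch.full((1, 5), 2, dtype=torch.long),
        "attention_mask": torch.ones(1, 5, dtype=torch.bool),
    },
    batch_size=[1],
)
out = concat_padded_tensors([a, b])
print(out["input_ids"].shape)  # expected: torch.Size([2, 5])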
|
||||
|
||||
|
||||
def to_device(data: Dict[str, torch.Tensor | Any], device) -> Dict[str, torch.Tensor]:
|
||||
|
@ -290,13 +290,13 @@ class MicroBatchList:
|
|||
DEFAULT_MAX_TOKENS_PER_MB = int(1e12)
|
||||
|
||||
|
||||
def split_packed_tensor_dict_into_mb_list(
|
||||
def split_padded_tensor_dict_into_mb_list(
|
||||
data: TensorDict, mb_spec: MicroBatchSpec, group: Optional[dist.ProcessGroup] = None
|
||||
) -> MicroBatchList:
|
||||
"""Split a packed tensordict into micro-batches based on the cumulative sequence lengths.
|
||||
"""Split a padded tensordict into micro-batches based on the attention mask.
|
||||
|
||||
Args:
|
||||
data (TensorDict): Dictionary containing packed tensors with "cu_seqlens" key.
|
||||
data (TensorDict): Dictionary containing padded tensors.
|
||||
mb_spec (MicroBatchSpec): Specification for micro-batch splitting.
|
||||
group (Optional[dist.ProcessGroup]): Process group for distributed synchronization.
|
||||
|
||||
|
@ -304,24 +304,21 @@ def split_packed_tensor_dict_into_mb_list(
|
|||
MicroBatchList: A structure containing the split micro-batches and metadata.
|
||||
"""
|
||||
assert (
|
||||
"cu_seqlens" in data
|
||||
), "Input data must be packed and contain 'cu_seqlens' key."
|
||||
"attention_mask" in data
|
||||
), "Input data must be padded and contain 'attention_mask' key."
|
||||
if mb_spec.max_tokens_per_mb is None:
|
||||
mb_spec = MicroBatchSpec.new(
|
||||
mb_spec, max_tokens_per_mb=DEFAULT_MAX_TOKENS_PER_MB
|
||||
)
|
||||
cu_seqlens = data["cu_seqlens"]
|
||||
bs = cu_seqlens.shape[0] - 1
|
||||
total_lens = int(cu_seqlens[-1])
|
||||
input_lens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy()
|
||||
bs = data["attention_mask"].shape[0]
|
||||
max_seqlen = data["attention_mask"].shape[1]
|
||||
input_lens = data["attention_mask"].sum(1).long().cpu().numpy()
|
||||
|
||||
# check tensor shape, split only 1d tensors with length "total_lens"
|
||||
to_split = {}
|
||||
not_to_split = {}
|
||||
for key, value in data.items():
|
||||
if key == "cu_seqlens" or key == "max_seqlen":
|
||||
continue
|
||||
if not torch.is_tensor(value) or value.numel() != total_lens:
|
||||
if not torch.is_tensor(value) or value.numel() != bs * max_seqlen:
|
||||
not_to_split[key] = value
|
||||
else:
|
||||
to_split[key] = value
|
||||
|
@ -331,6 +328,7 @@ def split_packed_tensor_dict_into_mb_list(
|
|||
splitted_lens = [
|
||||
[input_lens[i] for i in group_index] for group_index in group_indices
|
||||
]
|
||||
group_n_seqs = [len(x) for x in splitted_lens]
|
||||
group_lens = [sum(x) for x in splitted_lens]
|
||||
|
||||
forward_indices = datapack.flat2d(group_indices)
|
||||
|
@ -340,12 +338,16 @@ def split_packed_tensor_dict_into_mb_list(
|
|||
def _split(tensor):
|
||||
"""Split and pad a tensor based on forward indices and lens."""
|
||||
# Unpack the sequence
|
||||
unpacked = unpack_sequence(tensor, cu_seqlens=cu_seqlens)
|
||||
unpacked = [tensor[i] for i in range(bs)]
|
||||
# Reorder according to forward indices
|
||||
reordered = reorder_list(unpacked, forward_indices)
|
||||
reordered = torch.cat(reordered)
|
||||
reordered = torch.stack(reordered)
|
||||
# Unpack again according to split lens
|
||||
splitted = unpack_sequence(reordered, lens=group_lens)
|
||||
splitted = []
|
||||
offset = 0
|
||||
for _n_seqs in group_n_seqs:
|
||||
splitted.append(reordered[offset : offset + _n_seqs])
|
||||
offset += _n_seqs
|
||||
return splitted
|
||||
|
||||
to_split = dict_map(to_split, lambda x: _split(x))
|
||||
|
@ -355,16 +357,7 @@ def split_packed_tensor_dict_into_mb_list(
|
|||
# organize splitted micro batches
|
||||
assert len(mbs) == len(splitted_lens), (len(mbs), len(splitted_lens))
|
||||
for i, (mb, lens) in enumerate(zip(mbs, splitted_lens)):
|
||||
max_seqlen = max(lens)
|
||||
lens = torch.tensor(lens, device="cuda")
|
||||
batch_cu_seqlens = torch.nn.functional.pad(
|
||||
lens.cumsum(0, dtype=torch.int), (1, 0)
|
||||
)
|
||||
results.append(
|
||||
TensorDict(
|
||||
**mb, **not_to_split, max_seqlen=max_seqlen, cu_seqlens=batch_cu_seqlens
|
||||
)
|
||||
)
|
||||
results.append(TensorDict(**mb, **not_to_split))
|
||||
return MicroBatchList(
|
||||
data=data,
|
||||
mbs=results,
|
||||
|
@ -433,7 +426,7 @@ def pad_mb_list(
|
|||
# NOTE: GPU page size is 2MB
|
||||
# Take hidden size 4096 with bf16 dtype as an example,
|
||||
# the batch size of a page is 256
|
||||
pad_to_length = (l + 255) // 256 * 256
|
||||
pad_to_length = (int(l) + 255) // 256 * 256
|
||||
padded_mb, pad_len = pad_packed_tensor_dict(
|
||||
mb, pad_to_length, pad_value=pad_value
|
||||
)
|
||||
|
|
|
@ -0,0 +1,31 @@
from typing import Tuple

import torch
import torch.distributed as dist

from realhf.base import logging

logger = logging.getLogger(__file__)


def _get_current_mem_info(unit: str = "GB", precision: int = 2) -> Tuple[str, str, str, str]:
    """Get current memory usage."""
    assert unit in ["GB", "MB", "KB"]
    divisor = 1024**3 if unit == "GB" else 1024**2 if unit == "MB" else 1024
    mem_allocated = torch.cuda.memory_allocated()
    mem_reserved = torch.cuda.memory_reserved()
    mem_free, mem_total = torch.cuda.mem_get_info()
    mem_used = mem_total - mem_free
    mem_allocated = f"{mem_allocated / divisor:.{precision}f}"
    mem_reserved = f"{mem_reserved / divisor:.{precision}f}"
    mem_used = f"{mem_used / divisor:.{precision}f}"
    mem_total = f"{mem_total / divisor:.{precision}f}"
    return mem_allocated, mem_reserved, mem_used, mem_total


# Adapted from verl
def log_gpu_stats(head: str, rank: int = 0):
    if (not dist.is_initialized()) or (rank is None) or (dist.get_rank() == rank):
        mem_allocated, mem_reserved, mem_used, mem_total = _get_current_mem_info()
        message = f"{head}, memory allocated (GB): {mem_allocated}, memory reserved (GB): {mem_reserved}, device memory used/total (GB): {mem_used}/{mem_total}"
        logger.info(msg=message)
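Hypothetical call sites (not part of this diff): bracketing a training step shows how allocated and reserved CUDA memory move.

log_gpu_stats("before actor.ppo_update")
# ... run a forward/backward/optimizer step here ...
log_gpu_stats("after actor.ppo_update")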
|
|
@ -1,13 +1,9 @@
|
|||
from typing import TYPE_CHECKING, Any, Callable
|
||||
from typing import Callable
|
||||
|
||||
from arealite.api.cli_args import EvaluatorConfig
|
||||
from arealite.api.io_struct import FinetuneSpec
|
||||
from realhf.base import timeutil
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tensordict import TensorDict
|
||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
|
||||
|
||||
class Evaluator:
|
||||
|
||||
|
@ -22,8 +18,7 @@ class Evaluator:
|
|||
|
||||
def evaluate(
|
||||
self,
|
||||
valid_dataloader: "StatefulDataLoader",
|
||||
evaluate_fn: Callable[["TensorDict"], Any],
|
||||
evaluate_fn: Callable,
|
||||
epoch: int,
|
||||
step: int,
|
||||
global_step: int,
|
||||
|
@ -32,5 +27,4 @@ class Evaluator:
|
|||
epochs=int(step == self.ft_sepc.steps_per_epoch - 1), steps=1
|
||||
):
|
||||
return
|
||||
for data in valid_dataloader:
|
||||
evaluate_fn(data)
|
||||
evaluate_fn()
|
||||
|
|
|
@ -1,8 +1,175 @@
|
|||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
@torch.compile
|
||||
def gather_logprobs(logits: torch.Tensor, labels: torch.Tensor):
|
||||
log_probs = torch.nn.functional.log_softmax(logits.float(), dim=-1)
|
||||
def _gather_logprobs(
|
||||
logits: torch.Tensor, labels: torch.Tensor, temperature: float = 1.0
|
||||
):
|
||||
log_probs = torch.nn.functional.log_softmax(logits.float() / temperature, dim=-1)
|
||||
log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
|
||||
return log_probs_labels
|
||||
|
||||
|
||||
@torch.compile
|
||||
def _gather_logprobs_entropy(
|
||||
logits: torch.Tensor, labels: torch.Tensor, temperature: float = 1.0
|
||||
):
|
||||
log_probs = torch.nn.functional.log_softmax(logits.float() / temperature, dim=-1)
|
||||
entropy = -torch.sum(log_probs.exp() * log_probs, dim=-1)
|
||||
log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
|
||||
return log_probs_labels, entropy
|
||||
|
||||
|
||||
def gather_logprobs(
|
||||
logits: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
temperature: float = 1.0,
|
||||
chunk_size: int = 1024,
|
||||
):
|
||||
batch_size = logits.shape[0]
|
||||
|
||||
if batch_size <= chunk_size:
|
||||
return _gather_logprobs(logits, labels, temperature)
|
||||
|
||||
log_probs_labels_list = []
|
||||
|
||||
for i in range(0, batch_size, chunk_size):
|
||||
end_idx = min(i + chunk_size, batch_size)
|
||||
chunk_logits = logits[i:end_idx]
|
||||
chunk_labels = labels[i:end_idx]
|
||||
|
||||
chunk_log_probs = _gather_logprobs(chunk_logits, chunk_labels, temperature)
|
||||
|
||||
log_probs_labels_list.append(chunk_log_probs)
|
||||
|
||||
return torch.cat(log_probs_labels_list)
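A small shape-check sketch with toy numbers: a packed batch of 6 tokens over a 32-entry vocabulary, processed in chunks of 4 so the peak memory of the softmax stays bounded.

import torch

logits = torch.randn(6, 32)
labels = torch.randint(0, 32, (6,))
logp = gather_logprobs(logits, labels, temperature=1.0, chunk_size=4)
print(logp.shape)  # expected: torch.Size([6])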
|
||||
|
||||
|
||||
def gather_logprobs_entropy(
|
||||
logits: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
temperature: float = 1.0,
|
||||
chunk_size: int = 1024,
|
||||
):
|
||||
batch_size = logits.shape[0]
|
||||
|
||||
if batch_size <= chunk_size:
|
||||
return _gather_logprobs_entropy(logits, labels, temperature)
|
||||
|
||||
log_probs_labels_list = []
|
||||
entropy_list = []
|
||||
|
||||
for i in range(0, batch_size, chunk_size):
|
||||
end_idx = min(i + chunk_size, batch_size)
|
||||
chunk_logits = logits[i:end_idx]
|
||||
chunk_labels = labels[i:end_idx]
|
||||
|
||||
chunk_log_probs, chunk_entropy = _gather_logprobs_entropy(
|
||||
chunk_logits, chunk_labels, temperature
|
||||
)
|
||||
|
||||
log_probs_labels_list.append(chunk_log_probs)
|
||||
entropy_list.append(chunk_entropy)
|
||||
|
||||
return torch.cat(log_probs_labels_list), torch.cat(entropy_list)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def masked_normalization(
|
||||
x: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
dim=None,
|
||||
unbiased=False,
|
||||
eps=1e-5,
|
||||
high_precision=True,
|
||||
all_reduce=True,
|
||||
reduce_group=None,
|
||||
):
|
||||
dtype = torch.float64 if high_precision else torch.float32
|
||||
x = x.to(dtype)
|
||||
if dim is None:
|
||||
dim = tuple(range(len(x.shape)))
|
||||
if mask is None:
|
||||
factor = torch.tensor(
|
||||
np.prod([x.shape[d] for d in dim]), dtype=dtype, device=x.device
|
||||
)
|
||||
else:
|
||||
mask = mask.to(dtype)
|
||||
x = x * mask
|
||||
factor = mask.sum(dim, keepdim=True)
|
||||
x_sum = x.sum(dim=dim, keepdim=True)
|
||||
x_sum_sq = x.square().sum(dim=dim, keepdim=True)
|
||||
if dist.is_initialized() and all_reduce:
|
||||
dist.all_reduce(factor, op=dist.ReduceOp.SUM, group=reduce_group)
|
||||
dist.all_reduce(x_sum, op=dist.ReduceOp.SUM, group=reduce_group)
|
||||
dist.all_reduce(
|
||||
x_sum_sq,
|
||||
op=dist.ReduceOp.SUM,
|
||||
group=reduce_group,
|
||||
)
|
||||
mean = x_sum / factor
|
||||
meansq = x_sum_sq / factor
|
||||
var = meansq - mean**2
|
||||
if unbiased:
|
||||
var *= factor / (factor - 1)
|
||||
return ((x - mean) / (var.sqrt() + eps)).float()
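Illustrative use (values are made up): whiten advantages over response tokens only, so masked prompt positions contribute neither to the mean nor to the variance.

import torch

adv = torch.randn(2, 8)
loss_mask = torch.zeros(2, 8)
loss_mask[:, 3:] = 1.0  # pretend the first three positions are prompt tokens
norm_adv = masked_normalization(adv, mask=loss_mask, all_reduce=False)
print(norm_adv.shape)  # torch.Size([2, 8])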
|
||||
|
||||
|
||||
def ppo_actor_loss_fn(
|
||||
logprobs: torch.Tensor,
|
||||
old_logprobs: torch.Tensor,
|
||||
advantages: torch.Tensor,
|
||||
eps_clip: float,
|
||||
loss_mask: torch.Tensor,
|
||||
c_clip: Optional[float] = None,
|
||||
proximal_logprobs: Optional[torch.Tensor] = None,
|
||||
behav_imp_weight_cap: Optional[float] = None,
|
||||
) -> Tuple[torch.Tensor, Dict]:
|
||||
denorm_logprobs = (
|
||||
proximal_logprobs if proximal_logprobs is not None else old_logprobs
|
||||
)
|
||||
loss_mask_count = loss_mask.count_nonzero() or 1
|
||||
ratio = torch.where(loss_mask, torch.exp(logprobs - denorm_logprobs), 0)
|
||||
clipped_ratio = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip)
|
||||
pg_loss1 = -advantages * ratio
|
||||
pg_loss2 = -advantages * clipped_ratio
|
||||
clip_mask = pg_loss1.detach() < pg_loss2.detach()
|
||||
pg_loss = torch.max(pg_loss1, pg_loss2)
|
||||
if c_clip is not None:
|
||||
assert c_clip > 1.0, c_clip
|
||||
pg_loss3 = torch.sign(advantages) * c_clip * advantages
|
||||
dual_clip_mask = pg_loss3.detach() < pg_loss.detach()
|
||||
pg_loss = torch.min(pg_loss, pg_loss3)
|
||||
else:
|
||||
dual_clip_mask = torch.zeros_like(clip_mask)
|
||||
if proximal_logprobs is not None:
|
||||
behav_kl = proximal_logprobs - old_logprobs
|
||||
behav_imp_weight = behav_kl.exp()
|
||||
behav_mask = (
|
||||
(behav_imp_weight <= behav_imp_weight_cap).logical_and(loss_mask)
|
||||
if behav_imp_weight_cap is not None
|
||||
else loss_mask
|
||||
)
|
||||
behav_kl = torch.where(behav_mask, behav_kl, 0.0)
|
||||
behav_imp_weight = torch.where(behav_mask, behav_imp_weight, 0.0)
|
||||
pg_loss = pg_loss * behav_imp_weight
|
||||
logging_loss = pg_loss.detach()
|
||||
pg_loss = torch.where(loss_mask, pg_loss, 0).sum() / loss_mask_count
|
||||
clip_mask.logical_and_(loss_mask)
|
||||
dual_clip_mask.logical_and_(loss_mask)
|
||||
stat = dict(
|
||||
loss=logging_loss,
|
||||
importance_weight=ratio.detach(),
|
||||
approx_kl=(logprobs - denorm_logprobs).detach(),
|
||||
clip_mask=clip_mask,
|
||||
dual_clip_mask=dual_clip_mask,
|
||||
)
|
||||
if proximal_logprobs is not None:
|
||||
stat["behave_imp_weight"] = behav_imp_weight
|
||||
stat["behave_approx_kl"] = behav_kl
|
||||
stat["behave_mask"] = behav_mask
|
||||
return pg_loss, stat
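A minimal sketch of calling the loss in its decoupled form (tensor contents are random and purely illustrative); passing `proximal_logprobs` corresponds to `use_decoupled_loss: true` and `behav_imp_weight_cap` in the GRPO config later in this commit.

import torch

T = 16
logprobs = torch.randn(T)            # current policy log-probs
old_logprobs = logprobs - 0.10       # behavior (rollout-time) policy
proximal_logprobs = logprobs - 0.05  # recomputed proximal policy
advantages = torch.randn(T)
loss_mask = torch.ones(T, dtype=torch.bool)

loss, stat = ppo_actor_loss_fn(
    logprobs,
    old_logprobs,
    advantages,
    eps_clip=0.2,
    loss_mask=loss_mask,
    proximal_logprobs=proximal_logprobs,
    behav_imp_weight_cap=5.0,
)
print(float(loss))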
|
||||
|
|
|
@ -0,0 +1,56 @@
import asyncio
from typing import Any, Dict, Optional

import aiohttp

DEFAULT_RETRIES = 1
DEFAULT_REQUEST_TIMEOUT = 3600


async def arequest_with_retry(
    addr: str,
    endpoint: str,
    payload: Optional[Dict[str, Any]] = None,
    method: str = "POST",
    max_retries: Optional[int] = None,
    timeout: Optional[float] = None,
    retry_delay: float = 1.0,
) -> Dict[str, Any]:
    timeout = timeout or DEFAULT_REQUEST_TIMEOUT
    last_exception = None
    max_retries = max_retries or DEFAULT_RETRIES
    base_url = f"http://{addr}"
    url = f"{base_url}{endpoint}"

    for attempt in range(max_retries):
        try:
            async with aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(
                    total=timeout,
                    sock_connect=timeout,
                )
            ) as session:
                if method.upper() == "GET":
                    response = await session.get(url)
                elif method.upper() == "POST":
                    response = await session.post(url, json=payload)
                elif method.upper() == "PUT":
                    response = await session.put(url, json=payload)
                elif method.upper() == "DELETE":
                    response = await session.delete(url)
                else:
                    raise ValueError(f"Unsupported HTTP method: {method}")
                response.raise_for_status()
                return await response.json()
        except (
            aiohttp.ClientError,
            aiohttp.ClientResponseError,
            asyncio.TimeoutError,
        ) as e:
            last_exception = e
            if attempt < max_retries - 1:
                await asyncio.sleep(retry_delay)
                continue
    raise RuntimeError(
        f"Failed after {max_retries} retries. Last error: {last_exception}"
    )
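A hedged usage sketch (the server address is a placeholder): issue a minimal /generate call against one SGLang server, mirroring how agenerate in the remote engine uses this helper.

import asyncio


async def ping_generate(addr: str):
    payload = {
        "input_ids": [1, 2, 3],
        "sampling_params": {"max_new_tokens": 1, "temperature": 0.0},
        "return_logprob": True,
        "stream": False,
    }
    return await arequest_with_retry(
        addr, "/generate", payload=payload, max_retries=2, timeout=60
    )


# result = asyncio.run(ping_generate("localhost:30000"))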
|
|
@ -0,0 +1,8 @@
|
|||
import torch


# Copied from trl
def disable_dropout_in_model(model: torch.nn.Module) -> None:
    for module in model.modules():
        if isinstance(module, torch.nn.Dropout):
            module.p = 0
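Sketch of the intended call pattern: apply it to a freshly loaded Hugging Face model before training so recomputed log-probabilities are deterministic (the model name below is only illustrative).

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")
disable_dropout_in_model(model)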
|
|
@ -0,0 +1,100 @@
|
|||
import random
|
||||
import socket
|
||||
from typing import List, Set
|
||||
|
||||
|
||||
def gethostname():
|
||||
return socket.gethostname()
|
||||
|
||||
|
||||
def gethostip():
|
||||
return socket.gethostbyname(socket.gethostname())
|
||||
|
||||
|
||||
def find_free_ports(
|
||||
count: int, port_range: tuple = (1024, 65535), exclude_ports: Set[int] | None = None
|
||||
) -> List[int]:
|
||||
"""
|
||||
Find multiple free ports within a specified range.
|
||||
|
||||
Args:
|
||||
count: Number of free ports to find
|
||||
port_range: Tuple of (min_port, max_port) to search within
|
||||
exclude_ports: Set of ports to exclude from search
|
||||
|
||||
Returns:
|
||||
List of free port numbers
|
||||
|
||||
Raises:
|
||||
ValueError: If unable to find requested number of free ports
|
||||
"""
|
||||
if exclude_ports is None:
|
||||
exclude_ports = set()
|
||||
|
||||
min_port, max_port = port_range
|
||||
free_ports = []
|
||||
attempted_ports = set()
|
||||
|
||||
# Calculate available port range
|
||||
available_range = max_port - min_port + 1 - len(exclude_ports)
|
||||
|
||||
if count > available_range:
|
||||
raise ValueError(
|
||||
f"Cannot find {count} ports in range {port_range}. "
|
||||
f"Only {available_range} ports available."
|
||||
)
|
||||
|
||||
max_attempts = count * 10 # Reasonable limit to avoid infinite loops
|
||||
attempts = 0
|
||||
|
||||
while len(free_ports) < count and attempts < max_attempts:
|
||||
# Generate random port within range
|
||||
port = random.randint(min_port, max_port)
|
||||
|
||||
# Skip if port already attempted or excluded
|
||||
if port in attempted_ports or port in exclude_ports:
|
||||
attempts += 1
|
||||
continue
|
||||
|
||||
attempted_ports.add(port)
|
||||
|
||||
if is_port_free(port):
|
||||
free_ports.append(port)
|
||||
|
||||
attempts += 1
|
||||
|
||||
if len(free_ports) < count:
|
||||
raise ValueError(
|
||||
f"Could only find {len(free_ports)} free ports "
|
||||
f"out of {count} requested after {max_attempts} attempts"
|
||||
)
|
||||
|
||||
return sorted(free_ports)
|
||||
|
||||
|
||||
def is_port_free(port: int) -> bool:
|
||||
"""
|
||||
Check if a port is free by attempting to bind to it.
|
||||
|
||||
Args:
|
||||
port: Port number to check
|
||||
|
||||
Returns:
|
||||
True if port is free, False otherwise
|
||||
"""
|
||||
# Check TCP
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
try:
|
||||
sock.bind(("", port))
|
||||
sock.close()
|
||||
except OSError:
|
||||
return False
|
||||
|
||||
# Check UDP
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
try:
|
||||
sock.bind(("", port))
|
||||
sock.close()
|
||||
return True
|
||||
except OSError:
|
||||
return False
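Quick sketch of typical use, mirroring how the local launcher above reserves one server port and one dist-init port per SGLang instance.

server_port, dist_init_port = find_free_ports(2, port_range=(10000, 50000))
print(server_port, dist_init_port)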
|
|
@ -21,6 +21,18 @@ class Saver:
|
|||
freq_sec=config.freq_secs,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_save_checkpoint_root(
|
||||
config: SaverConfig,
|
||||
name: str = "default",
|
||||
):
|
||||
path = os.path.join(
|
||||
f"{config.fileroot}/checkpoints/{getpass.getuser()}/{config.experiment_name}/{config.trial_name}",
|
||||
name,
|
||||
)
|
||||
os.makedirs(path, exist_ok=True)
|
||||
return path
|
||||
|
||||
@staticmethod
|
||||
def get_save_checkpoint_path(
|
||||
config: SaverConfig,
|
||||
|
@ -30,8 +42,7 @@ class Saver:
|
|||
name: str = "default",
|
||||
):
|
||||
path = os.path.join(
|
||||
f"{config.fileroot}/checkpoints/{getpass.getuser()}/{config.experiment_name}/{config.trial_name}",
|
||||
name,
|
||||
Saver.get_save_checkpoint_root(config, name),
|
||||
f"epoch{epoch}epochstep{step}globalstep{globalstep}",
|
||||
)
|
||||
os.makedirs(path, exist_ok=True)
|
||||
|
@ -51,7 +62,9 @@ class Saver:
|
|||
epochs=int(step == self.ft_sepc.steps_per_epoch - 1), steps=1
|
||||
):
|
||||
return
|
||||
path = self.get_save_checkpoint_path(epoch, step, global_step, name)
|
||||
path = Saver.get_save_checkpoint_path(
|
||||
self.config, epoch, step, global_step, name
|
||||
)
|
||||
weight_format = "hf"
|
||||
with_optim = False
|
||||
if self.for_recover:
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import getpass
|
||||
import os
|
||||
import time
|
||||
from typing import Dict
|
||||
from typing import Dict, List
|
||||
|
||||
import torch.distributed as dist
|
||||
import wandb
|
||||
|
@ -21,6 +21,8 @@ class StatsLogger:
|
|||
self.ft_spec = ft_spec
|
||||
self.init()
|
||||
|
||||
self._last_commit_step = 0
|
||||
|
||||
def init(self):
|
||||
if dist.is_initialized() and dist.get_rank() != 0:
|
||||
return
|
||||
|
@ -61,7 +63,7 @@ class StatsLogger:
|
|||
if self.summary_writer is not None:
|
||||
self.summary_writer.close()
|
||||
|
||||
def commit(self, epoch: int, step: int, global_step: int, data: Dict):
|
||||
def commit(self, epoch: int, step: int, global_step: int, data: Dict | List[Dict]):
|
||||
if dist.is_initialized() and dist.get_rank() != 0:
|
||||
return
|
||||
self.info(
|
||||
|
@ -69,12 +71,17 @@ class StatsLogger:
|
|||
f"Step {step+1}/{self.ft_spec.steps_per_epoch} "
|
||||
f"Train step {global_step + 1}/{self.ft_spec.total_train_steps} done."
|
||||
)
|
||||
self.info("Stats:")
|
||||
self.print_stats(data)
|
||||
wandb.log(data, step=global_step)
|
||||
if self.summary_writer is not None:
|
||||
for key, val in data.items():
|
||||
self.summary_writer.add_scalar(f"{key}", val, global_step)
|
||||
if isinstance(data, Dict):
|
||||
data = [data]
|
||||
log_step = max(global_step, self._last_commit_step)
|
||||
for i, item in enumerate(data):
|
||||
self.info(f"Stats ({i+1}/{len(data)}):")
|
||||
self.print_stats(item)
|
||||
wandb.log(item, step=log_step + i)
|
||||
if self.summary_writer is not None:
|
||||
for key, val in item.items():
|
||||
self.summary_writer.add_scalar(f"{key}", val, log_step + i)
|
||||
self._last_commit_step = log_step + len(data) - 1
|
||||
|
||||
def print_stats(self, stats: Dict[str, float]):
|
||||
self.info("\n" + tabulate_stats(stats))
|
||||
|
|
|
@ -17,19 +17,24 @@ class RLVRWorkflow(RolloutWorkflow):
|
|||
reward_fn,
|
||||
gconfig: GenerationHyperparameters,
|
||||
tokenizer: PreTrainedTokenizerFast,
|
||||
enable_thinking: bool,
|
||||
):
|
||||
self.reward_fn = reward_fn
|
||||
self.gconfig = gconfig
|
||||
self.tokenizer = tokenizer
|
||||
self.enable_thinking = enable_thinking
|
||||
|
||||
async def arun_episode(self, engine, data):
|
||||
text = self.tokenizer.apply_chat_template(
|
||||
data["messages"], tokenize=False, add_generation_prompt=True
|
||||
input_ids = self.tokenizer.apply_chat_template(
|
||||
data["messages"],
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=self.enable_thinking,
|
||||
)
|
||||
n_samples = self.gconfig.n_samples
|
||||
req = LLMRequest(
|
||||
rid=uuid.uuid4().hex,
|
||||
text=text,
|
||||
input_ids=input_ids,
|
||||
gconfig=self.gconfig.new(n_samples=1),
|
||||
)
|
||||
resps = await asyncio.gather(*[engine.agenerate(req) for _ in range(n_samples)])
|
||||
|
@ -37,13 +42,13 @@ class RLVRWorkflow(RolloutWorkflow):
|
|||
results = []
|
||||
for resp in resps:
|
||||
seq = resp.input_tokens + resp.output_tokens
|
||||
logprobs = [0] * resp.input_len + resp.output_logprobs
|
||||
prompt_mask = [1] * resp.input_len + [0] * resp.output_len
|
||||
logprobs = [0.0] * resp.input_len + resp.output_logprobs
|
||||
loss_mask = [0] * resp.input_len + [1] * resp.output_len
|
||||
versions = [-1] * resp.input_len + resp.output_versions
|
||||
|
||||
reward = self.reward_fn(
|
||||
prompt=req.text,
|
||||
completions=resp.completions,
|
||||
prompt=self.tokenizer.decode(input_ids),
|
||||
completions=self.tokenizer.decode(resp.output_tokens),
|
||||
prompt_ids=resp.input_tokens,
|
||||
completion_ids=resp.output_tokens,
|
||||
**data,
|
||||
|
@ -51,10 +56,10 @@ class RLVRWorkflow(RolloutWorkflow):
|
|||
res = dict(
|
||||
# unsqueeze to add an additional batch dimension
|
||||
input_ids=torch.tensor(seq).unsqueeze(0),
|
||||
prompt_mask=torch.tensor(prompt_mask).unsqueeze(0),
|
||||
loss_mask=torch.tensor(loss_mask).unsqueeze(0),
|
||||
logprobs=torch.tensor(logprobs).unsqueeze(0),
|
||||
versions=torch.tensor(versions).unsqueeze(0),
|
||||
attention_mask=torch.ones(len(seq)).unsqueeze(0),
|
||||
attention_mask=torch.ones(len(seq), dtype=torch.bool).unsqueeze(0),
|
||||
# reward
|
||||
rewards=torch.tensor([reward]),
|
||||
)
|
||||
|
|
|
@ -0,0 +1,129 @@
|
|||
experiment_name: gsm8k-grpo
|
||||
trial_name: trial0
|
||||
allocation_mode: sglang.d4p1t1+d4p1t1
|
||||
n_nodes: 1
|
||||
n_gpus_per_node: 8
|
||||
cluster:
|
||||
fileroot: /tmp/arealite/experiments
|
||||
name_resolve:
|
||||
type: nfs
|
||||
nfs_record_root: /tmp/areal/name_resolve
|
||||
seed: 1
|
||||
total_train_epochs: 10
|
||||
tokenizer_path: ${actor.path}
|
||||
async_training: true
|
||||
|
||||
rollout:
|
||||
experiment_name: ${experiment_name}
|
||||
trial_name: ${trial_name}
|
||||
max_concurrent_rollouts: 256
|
||||
queue_size: null
|
||||
consumer_batch_size: ${train_dataset.batch_size}
|
||||
max_head_offpolicyness: 4
|
||||
enable_rollout_tracing: false
|
||||
|
||||
gconfig:
|
||||
n_samples: 4
|
||||
min_new_tokens: 0
|
||||
max_new_tokens: 512
|
||||
greedy: false
|
||||
temperature: 1.0
|
||||
|
||||
actor:
|
||||
experiment_name: ${experiment_name}
|
||||
trial_name: ${trial_name}
|
||||
path: /storage/openpsi/models/Qwen__Qwen2-1.5B-Instruct/
|
||||
init_from_scratch: false
|
||||
disable_dropout: true
|
||||
gradient_checkpointing: false
|
||||
dtype: bfloat16
|
||||
mb_spec:
|
||||
max_tokens_per_mb: 10240
|
||||
optimizer:
|
||||
type: adam
|
||||
lr: 2e-6
|
||||
weight_decay: 0.01
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
eps: 1e-8
|
||||
lr_scheduler_type: constant
|
||||
gradient_clipping: 1.0
|
||||
warmup_steps_proportion: 0.001
|
||||
backend: fsdp
|
||||
|
||||
group_size: ${gconfig.n_samples}
|
||||
group_adv_norm: false
|
||||
eps_clip: 0.4
|
||||
temperature: ${gconfig.temperature}
|
||||
reward_scaling: 10.0
|
||||
reward_bias: -0.5
|
||||
kl_ctl: 0.0
|
||||
ppo_n_minibatches: 1
|
||||
recompute_logprob: true
|
||||
use_decoupled_loss: true
|
||||
behav_imp_weight_cap: 5.0
|
||||
|
||||
ref:
|
||||
experiment_name: ${experiment_name}
|
||||
trial_name: ${trial_name}
|
||||
path: ${actor.path}
|
||||
init_from_scratch: false
|
||||
dtype: ${actor.dtype}
|
||||
mb_spec:
|
||||
max_tokens_per_mb: 10240
|
||||
optimizer: null
|
||||
backend: fsdp
|
||||
|
||||
# SGLang
|
||||
server_only: false
|
||||
sglang:
|
||||
model_path: ${actor.path}
|
||||
random_seed: ${seed}
|
||||
skip_tokenizer_init: true
|
||||
dtype: ${actor.dtype}
|
||||
max_running_requests: null
|
||||
context_length: 32768
|
||||
mem_fraction_static: 0.9
|
||||
|
||||
# datasets
|
||||
train_dataset:
|
||||
batch_size: 256
|
||||
shuffle: true
|
||||
pin_memory: true
|
||||
|
||||
valid_dataset:
|
||||
batch_size: 256
|
||||
shuffle: true
|
||||
pin_memory: true
|
||||
|
||||
# Utilities
|
||||
saver:
|
||||
experiment_name: ${experiment_name}
|
||||
trial_name: ${trial_name}
|
||||
fileroot: ${cluster.fileroot}
|
||||
freq_epochs: 1
|
||||
freq_steps: null
|
||||
freq_secs: null
|
||||
|
||||
checkpointer:
|
||||
experiment_name: ${experiment_name}
|
||||
trial_name: ${trial_name}
|
||||
fileroot: ${cluster.fileroot}
|
||||
freq_epochs: 1
|
||||
freq_steps: null
|
||||
freq_secs: 3600
|
||||
|
||||
evaluator:
|
||||
experiment_name: ${experiment_name}
|
||||
trial_name: ${trial_name}
|
||||
fileroot: ${cluster.fileroot}
|
||||
freq_epochs: 1
|
||||
freq_steps: null
|
||||
freq_secs: null
|
||||
|
||||
stats_logger:
|
||||
experiment_name: ${experiment_name}
|
||||
trial_name: ${trial_name}
|
||||
fileroot: ${cluster.fileroot}
|
||||
wandb:
|
||||
mode: disabled
|
|
@ -16,7 +16,7 @@ model:
|
|||
path: /storage/openpsi/models/Qwen__Qwen3-1.7B/
|
||||
init_from_scratch: false
|
||||
gradient_checkpointing: false
|
||||
bf16: true
|
||||
dtype: bfloat16
|
||||
mb_spec:
|
||||
max_tokens_per_mb: 4096
|
||||
optimizer:
|
||||
|
@ -34,13 +34,11 @@ train_dataset:
|
|||
batch_size: 128
|
||||
shuffle: true
|
||||
pin_memory: true
|
||||
num_workers: 4
|
||||
|
||||
valid_dataset:
|
||||
batch_size: 128
|
||||
shuffle: true
|
||||
pin_memory: true
|
||||
num_workers: 4
|
||||
|
||||
# Utilities
|
||||
saver:
|
||||
|
|
|
@ -0,0 +1,259 @@
import os
import re
import sys

import torch
import torch.distributed as dist
from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node
from torchdata.stateful_dataloader import StatefulDataLoader

from arealite.api.cli_args import GRPOConfig, load_expr_config
from arealite.api.io_struct import FinetuneSpec, WeightUpdateMeta
from arealite.engine.ppo.actor import FSDPPPOActor
from arealite.engine.sglang_remote import RemoteSGLangEngine
from arealite.utils.device import log_gpu_stats
from arealite.utils.evaluator import Evaluator
from arealite.utils.saver import Saver
from arealite.utils.stats_logger import StatsLogger
from arealite.workflow.rlvr import RLVRWorkflow
from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import stats_tracker


def process_gsm8k_rl_dataset(dataset: Dataset):
    def process(sample):
        messages = [{"role": "user", "content": sample["question"]}]
        return {"messages": messages}

    dataset = dataset.map(process).remove_columns(["question"])
    return dataset


def get_gsm8k_dataset(split, rank, world_size):
    dataset = load_dataset(path="openai/gsm8k", name="main", split=split)
    dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
    return process_gsm8k_rl_dataset(dataset)


# Adapted from verl.
def extract_solution(solution_str, method="strict") -> str | None:
    assert method in ["strict", "flexible"]

    final_answer = None
    if method == "strict":
        # this also tests the formatting of the model
        solutions = re.findall("#### (\\-?[0-9\\.\\,]+)", solution_str)
        if len(solutions) == 0:
            final_answer = None
        else:
            # take the last solution
            final_answer = solutions[-1].replace(",", "").replace("$", "")
    elif method == "flexible":
        answer = re.findall("(\\-?[0-9\\.\\,]+)", solution_str)
        final_answer = None
        if len(answer) == 0:
            # no reward if there is no answer
            pass
        else:
            invalid_str = ["", "."]
            # find the last number that is not '.'
            for final_answer in reversed(answer):
                if final_answer not in invalid_str:
                    break
    return final_answer
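
# Example: for a reference answer ending in "#### 1,234", strict mode returns
# "1234" (commas stripped); flexible mode instead takes the last bare number
# anywhere in the string, e.g. "72" from "... so she has 72 left".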


def gsm8k_reward_fn(prompt, completions, prompt_ids, completion_ids, answer, **kwargs):
    from realhf.impl.dataset.math_parser import extract_answer

    sol = extract_answer(completions, data_name="math")
    ans = extract_solution(solution_str=answer, method="strict")
    if sol is None:
        return 0
    if ans is None:
        return 0
    return int(sol.strip() == ans.strip())
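
# The reward is binary: 1 when the answer parsed from the completion matches
# the "#### ..." reference answer, 0 when either side fails to parse or the
# values differ.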


def main_grpo():
    config, _ = load_expr_config(sys.argv[1:], GRPOConfig)
    config: GRPOConfig

    rank = int(os.getenv("RANK"))
    world_size = int(os.getenv("WORLD_SIZE"))
    tokenizer = load_hf_tokenizer(config.tokenizer_path)

    # Create dataset and dataloaders
    train_dataloader = StatefulDataLoader(
        get_gsm8k_dataset("train", rank, world_size),
        batch_size=config.train_dataset.batch_size // world_size,
        shuffle=config.train_dataset.shuffle,
        num_workers=config.train_dataset.num_workers,
        collate_fn=lambda x: x,
        drop_last=config.train_dataset.drop_last,
    )
    valid_dataloader = StatefulDataLoader(
        get_gsm8k_dataset("test", rank, world_size),
        batch_size=config.valid_dataset.batch_size // world_size,
        shuffle=config.valid_dataset.shuffle,
        num_workers=config.valid_dataset.num_workers,
        collate_fn=lambda x: x,
        drop_last=config.valid_dataset.drop_last,
    )
    ft_spec = FinetuneSpec(
        total_train_epochs=config.total_train_epochs,
        dataset_size=len(train_dataloader) * config.train_dataset.batch_size,
        train_batch_size=config.train_dataset.batch_size,
    )

    # Initialize inference engine
    rollout = RemoteSGLangEngine(config.rollout)
    rollout.initialize(None, ft_spec)
    eval_rollout = RemoteSGLangEngine(config.rollout)
    eval_rollout.initialize(None, ft_spec)
    # NOTE: set a large version such that eval does not have any offpolicyness control
    eval_rollout.set_version(int(1e12))

    # Initialize train engine
    actor = FSDPPPOActor(config=config.actor)
    actor.initialize(None, ft_spec)
    ref = None
    if config.actor.kl_ctl > 0 and config.ref is not None:
        ref = FSDPPPOActor(config=config.ref)
        ref.initialize(None, ft_spec)

    # Create rollout workflow
    if tokenizer.pad_token_id not in config.gconfig.stop_token_ids:
        config.gconfig.stop_token_ids.append(tokenizer.pad_token_id)
    if tokenizer.eos_token_id not in config.gconfig.stop_token_ids:
        config.gconfig.stop_token_ids.append(tokenizer.eos_token_id)
    workflow = RLVRWorkflow(
        reward_fn=gsm8k_reward_fn,
        gconfig=config.gconfig,
        tokenizer=tokenizer,
        enable_thinking=False,
    )

    # Run training.
    saver = Saver(config.saver, ft_spec, for_recover=False)
    logger = StatsLogger(config.stats_logger, ft_spec)
    evaluator = Evaluator(config.evaluator, ft_spec)

    total_epochs = config.total_train_epochs
    steps_per_epoch = len(train_dataloader)
    max_steps = total_epochs * steps_per_epoch

    logger.info(f"total_epochs={total_epochs} step_per_epoch={steps_per_epoch}")
    data_generator = iter(train_dataloader)
    for global_step in range(max_steps):
        epoch = global_step // steps_per_epoch
        step = global_step % steps_per_epoch

        with stats_tracker.record_timing("rollout"):
            if config.async_training:
                batch = rollout.prepare_batch(
                    data_generator,
                    train_dataloader,
                    workflow=workflow,
                )
            else:
                try:
                    data = next(data_generator)
                except StopIteration:
                    data_generator = iter(train_dataloader)
                    data = next(data_generator)
                batch = rollout.rollout_batch(data, workflow=workflow)

        batch = batch.to(actor.device)
        # Create barrier to synchronize all rollout processes.
        dist.barrier()
        torch.cuda.synchronize()

        if config.actor.recompute_logprob or config.actor.use_decoupled_loss:
            with stats_tracker.record_timing("recompute_logp"):
                logp = actor.compute_logp(batch)
                batch["prox_logp"] = logp
                log_gpu_stats("recompute logp")

        if ref is not None:
            with stats_tracker.record_timing("ref_logp"):
                batch["ref_logp"] = ref.compute_logp(batch)
                log_gpu_stats("ref logp")

        with stats_tracker.record_timing("compute_advantage"):
            actor.compute_advantages(batch)
            log_gpu_stats("compute advantages")

        with (
            stats_tracker.record_timing("train_step"),
            stats_tracker.scope("grpo_actor"),
        ):
            stats = actor.ppo_update(batch)
            actor.step_lr_scheduler()
            log_gpu_stats("ppo update")

        with stats_tracker.record_timing("update_weights"):
            meta = WeightUpdateMeta(
                type="disk",
                path=os.path.join(
                    Saver.get_save_checkpoint_root(config.saver),
                    "update_weights",
                    str(global_step),
                ),
                alloc_mode=None,
                comm_backend=None,
                model_version=global_step + 1,
            )
            if dist.get_rank() == 0:
                future = rollout.update_weights(meta)
            actor.upload_weights(meta)
            if dist.get_rank() == 0:
                future.result()
            rollout.set_version(global_step + 1)
            dist.barrier()

        with stats_tracker.record_timing("save"):
            saver.save(actor, epoch, step, global_step)

        with stats_tracker.record_timing("eval"):

            def evaluate_fn():
                rollout.pause()
                cnt = 0
                for data in valid_dataloader:
                    for item in data:
                        eval_rollout.submit(item, workflow)
                        cnt += 1
                batch = eval_rollout.wait(cnt, timeout=None)
                rewards = batch["rewards"].float().to(actor.device)
                with stats_tracker.scope("grpo-eval"):
                    stats_tracker.denominator(
                        n_seqs=torch.ones(
                            rewards.shape[0],
                            device=rewards.device,
                            dtype=torch.bool,
                        )
                    )
                    stats_tracker.stat(task_reward=rewards, denominator="n_seqs")
                rollout.resume()

            evaluator.evaluate(
                evaluate_fn,
                epoch,
                step,
                global_step,
            )

        logger.commit(epoch, step, global_step, stats)

    logger.close()
    eval_rollout.destroy()
    rollout.destroy()
    if ref is not None:
        ref.destroy()
    actor.destroy()


if __name__ == "__main__":
    main_grpo()
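The reward callback handed to RLVRWorkflow is a plain function, so the script can be reused for other verifiable tasks by swapping gsm8k_reward_fn for another function with the same signature. A minimal sketch of such a drop-in; the exact-match rule here is a hypothetical example, not part of this commit:

def exact_match_reward_fn(prompt, completions, prompt_ids, completion_ids, answer, **kwargs):
    # Hypothetical rule: full reward only when the reference answer string
    # appears verbatim in the completion; plug in any stricter parser instead.
    return float(answer.strip() in completions)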
@ -22,10 +22,8 @@ def process_gsm8k_sft_dataset(dataset: Dataset, tokenizer):
            sample["question"] + sample["answer"] + tokenizer.eos_token
        )
        prompt_token = tokenizer.encode(sample["question"])
        prompt_mask = [1] * len(prompt_token) + [0] * (
            len(seq_token) - len(prompt_token)
        )
        return {"input_ids": seq_token, "prompt_mask": prompt_mask}
        loss_mask = [0] * len(prompt_token) + [1] * (len(seq_token) - len(prompt_token))
        return {"input_ids": seq_token, "loss_mask": loss_mask}

    dataset = dataset.map(process).remove_columns(["question", "answer"])
    return dataset
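The replacement flips the mask semantics: loss_mask is 0 over the prompt tokens and 1 over the answer tokens, so the SFT loss is taken only on the answer. A small worked example with made-up token ids:

prompt_token = [101, 102]            # tokenized question (hypothetical ids)
seq_token = [101, 102, 7, 8, 9]      # question + answer + eos
loss_mask = [0] * len(prompt_token) + [1] * (len(seq_token) - len(prompt_token))
assert loss_mask == [0, 0, 1, 1, 1]  # loss applies to the answer tokens only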
@ -95,22 +93,31 @@ def main_sft():
        with stats_tracker.record_timing("save"):
            saver.save(engine, epoch, step, global_step)

        with stats_tracker.record_timing("eval"), stats_tracker.scope("sft-eval"):
        with stats_tracker.record_timing("eval"):
            # No need to log anything. Logging will be handled outside
            # via stats_tracker.export().
            def evaluate_fn():
                with stats_tracker.scope("sft-eval"):
                    for data in valid_dataloader:
                        engine.evaluate_lm(data)

            evaluator.evaluate(
                valid_dataloader,
                engine.evaluate_lm,
                evaluate_fn,
                epoch,
                step,
                global_step,
            )

        logger.commit(epoch, step, global_step, stats_tracker.export())
        logger.commit(
            epoch,
            step,
            global_step,
            stats_tracker.export(reduce_group=engine.parallelism_group),
        )
        global_step += 1

    engine.destroy()
    logger.close()
    engine.destroy()


if __name__ == "__main__":
@ -41,12 +41,12 @@ class JobException(Exception):
class JobInfo:
    name: str
    state: JobState
    host: str = (
    host: Optional[str] = (
        None  # The host on which the job is/was running. None if the job had not run.
    )
    submit_time: str = None
    start_time: str = None
    slurm_id: str = None  # Slurm only. The Slurm id of the job.
    submit_time: Optional[str] = None
    start_time: Optional[str] = None
    slurm_id: Optional[int] = None  # Slurm only. The Slurm id of the job.


class SchedulerClient:
@ -40,7 +40,7 @@ def execute_shell_command(command: str) -> subprocess.Popen:
    )


def apply_sglang_path():
def apply_sglang_patch():
    p = Path(os.path.dirname(__file__))
    patch_path = str(
        p.parent.parent

@ -75,7 +75,7 @@ def launch_server_cmd(command: str, port: int = 30000):
    If no port is specified, a free port is reserved.
    """
    if not ray.is_initialized():
        apply_sglang_path()
        apply_sglang_patch()
    assert port is not None
    full_command = f"{command} --port {port}"
    process = execute_shell_command(full_command)