mirror of https://github.com/inclusionAI/AReaL
merge ppo
This commit is contained in: commit 57ce1213ae
@@ -127,7 +127,7 @@ class TrainEngineConfig:
    gradient_checkpointing: bool = field(
        default=True, metadata={"help": "Enable gradient checkpointing"}
    )
    bf16: bool = field(default=False, metadata={"help": "Use bf16 precision"})
    dtype: str = field(default="float16", metadata={"help": "Parameter dtype."})
    optimizer: Optional[OptimizerConfig] = field(
        default=None, metadata={"help": "Optimizer configuration"}
    )
@@ -135,12 +135,95 @@ class TrainEngineConfig:
    fsdp: FSDPEngineConfig = field(default_factory=FSDPEngineConfig)


@dataclass
class PPOActorConfig(TrainEngineConfig):
    # Core PPO/GRPO Parameters
    group_size: int = field(
        default=1, metadata={"help": "Number of sequences in each group"}
    )
    group_adv_norm: bool = field(
        default=False,
        metadata={
            "help": "Normalize advantages within each prompt group rather than globally"
        },
    )
    ppo_n_minibatches: int = field(
        default=4, metadata={"help": "Number of minibatches for each PPO update"}
    )
    eps_clip: float = field(
        default=0.2, metadata={"help": "Clipping factor for policy ratio"}
    )
    c_clip: Optional[float] = field(
        default=None,
        metadata={
            "help": "Dual clipping factor for policy ratio, must be > 1.0. None disables dual clipping."
        },
    )
    temperature: float = field(
        default=1.0, metadata={"help": "Temperature during generation."}
    )
    # Reward
    group_reward_norm: bool = field(
        default=False,
        metadata={
            "help": "Normalize the final reward of each sequence (GRPO-style) to reduce length bias"
        },
    )
    reward_scaling: float = field(
        default=1.0, metadata={"help": "Reward scaling factor"}
    )
    reward_bias: float = field(default=0.0, metadata={"help": "Reward bias"})
    reward_clip: float = field(
        default=20.0, metadata={"help": "Maximum absolute value for reward clipping"}
    )
    mask_no_eos_with_zero: bool = field(
        default=False,
        metadata={
            "help": "Mask truncated generations (no EOS token) and exclude them from training"
        },
    )

    # Advantage Estimation
    discount: float = field(
        default=1.0, metadata={"help": "Discount factor for future rewards"}
    )
    gae_lambda: float = field(
        default=1.0, metadata={"help": "Lambda parameter for GAE"}
    )
    adv_norm: bool = field(
        default=True, metadata={"help": "Enable advantage normalization"}
    )

    # KL Control
    kl_ctl: float = field(default=0.1, metadata={"help": "KL divergence coefficient"})

    # Asynchronous RL
    recompute_logprob: bool = field(
        default=False,
        metadata={"help": "Recompute logp and replace the logp returned by inference."},
    )
    use_decoupled_loss: bool = field(
        default=False,
        metadata={"help": "Use the decoupled loss. recompute_logprob must be True."},
    )
    behav_imp_weight_cap: Optional[float] = field(
        default=None,
        metadata={
            "help": "Filter out tokens whose behav_imp_weight exceeds this cap when computing the loss. Must be > 1.0; use_decoupled_loss must be True."
        },
    )


@dataclass
class SGLangConfig:
    """Configuration for SGLang runtime. Refer to:
    https://github.com/sgl-project/sglang for detailed documentation.
    """

    model_path: str = ""
    random_seed: int = 1
    skip_tokenizer_init: bool = False

    disable_cuda_graph: bool = False
    disable_radix_cache: bool = False
    disable_cuda_graph_padding: bool = False
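The new PPOActorConfig gathers the GRPO-specific knobs in one place. As a quick orientation, a hypothetical configuration for an asynchronous decoupled-loss run might combine the flags like this (field names come from the dataclass above; the concrete values, and the idea of pairing them this way, are illustrative only):

# Hypothetical settings for an async GRPO run (sketch, not from the repo's examples).
actor = PPOActorConfig(
    group_size=8,              # 8 sampled responses per prompt
    group_reward_norm=True,    # GRPO-style per-group reward whitening
    eps_clip=0.2,
    c_clip=3.0,                # dual clipping; must be > 1.0
    recompute_logprob=True,    # required before enabling the decoupled loss
    use_decoupled_loss=True,
    behav_imp_weight_cap=5.0,  # drop tokens with extreme behavior importance weights
)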
@@ -195,38 +278,27 @@ class SGLangConfig:
    @staticmethod
    def build_cmd(
        sglang_config: "SGLangConfig",
        model_path: str,
        tp_size: int,
        base_gpu_id: int,
        server_port: Optional[int] = None,
        tp_size,
        base_gpu_id,
        host,
        port,
        dist_init_addr: Optional[str] = None,
        seed: Optional[int] = None,
        served_model_name: Optional[str] = None,
        skip_tokenizer_init: bool = True,
    ) -> str:
    ):
        from realhf.base import network, pkg_version, seeding
        from realhf.experiments.common.utils import asdict as conf_as_dict

        args: Dict = conf_as_dict(sglang_config)
        if server_port is not None:
            args["port"] = server_port
        if served_model_name is None:
            served_model_name = model_path
        host = "localhost"

        args = dict(
            host=host,
            model_path=model_path,
            # seed
            seed=seed if seed is not None else seeding.get_seed(),
            port=port,
            # Model and tokenizer
            tokenizer_path=model_path,
            tokenizer_path=sglang_config.model_path,
            tokenizer_mode="auto",
            load_format="auto",
            trust_remote_code=True,
            device="cuda",
            served_model_name=served_model_name,
            is_embedding=False,
            skip_tokenizer_init=skip_tokenizer_init,
            # Other runtime options
            tp_size=tp_size,
            # Because we have set CUDA_VISIBLE_DEVICES to a single GPU in each process
@@ -266,8 +338,8 @@ class SGLangConfig:

@dataclass
class InferenceEngineConfig:
    experiment_name: str
    trial_name: str
    experiment_name: str = MISSING
    trial_name: str = MISSING
    max_concurrent_rollouts: None | int = field(
        default=None,
        metadata={
@@ -290,17 +362,13 @@ class InferenceEngineConfig:
            "the request will not be accepted.",
        },
    )
    # Used by remote inference engines.
    server_addrs: List[str] = field(
        default_factory=list,
        metadata={"help": "List of server addresses for inference."},
    )
    schedule_policy: str = field(
        default="round_robin",
        metadata={"help": "Request scheduling policy", "choices": ["round_robin"]},
    )
    setup_timeout: float = field(default=90.0)
    request_timeout: float = field(
        default=30.0, metadata={"help": "Timeout for HTTP requests."}
        default=3600, metadata={"help": "Timeout for HTTP requests."}
    )
    request_retries: int = field(
        default=3, metadata={"help": "Number of retries for failed requests."}
@@ -541,6 +609,9 @@ class BaseExperimentConfig:
    evaluator: EvaluatorConfig = field(default_factory=EvaluatorConfig)
    stats_logger: StatsLoggerConfig = field(default_factory=StatsLoggerConfig)

    server_only: bool = False
    sglang: SGLangConfig = field(default_factory=SGLangConfig)


@dataclass
class SFTConfig(BaseExperimentConfig):
@@ -549,15 +620,16 @@ class SFTConfig(BaseExperimentConfig):

@dataclass
class GRPOConfig(BaseExperimentConfig):
    actor: TrainEngineConfig = field(default_factory=TrainEngineConfig)
    ref: TrainEngineConfig = field(default_factory=TrainEngineConfig)
    rollout: InferenceEngineConfig = field(
        default_factory=InferenceEngineConfig
    )
    async_training: bool = field(default=True)
    gconfig: GenerationHyperparameters = field(
        default_factory=GenerationHyperparameters
    )
    sglang: SGLangConfig = field(default_factory=SGLangConfig)

    rollout: InferenceEngineConfig = field(default_factory=InferenceEngineConfig)
    actor: PPOActorConfig = field(default_factory=PPOActorConfig)
    ref: PPOActorConfig = field(default_factory=PPOActorConfig)


def load_expr_config(argv: List[str], config_cls) -> Tuple[BaseExperimentConfig, str]:

def parse_cli_args(argv: List[str]):
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config", help="The path of the main configuration file", required=True
@@ -568,19 +640,26 @@ def load_expr_config(argv: List[str], config_cls) -> Tuple[BaseExperimentConfig,
    config_file = Path(args.config).absolute()
    assert config_file.exists()
    # hydra only recognizes relative paths
    relpath = Path(
        os.path.relpath(str(config_file), (Path(__file__).parent).absolute())
    )
    relpath = Path(os.path.relpath(str(config_file), Path(__file__).parent.absolute()))
    hydra_init(config_path=str(relpath.parent), job_name="app", version_base=None)
    cfg = hydra_compose(
        config_name=str(relpath.name).rstrip(".yaml"),
        overrides=overrides,
    )
    return cfg, config_file


def to_structured_cfg(cfg, config_cls):
    # Merge with the default configuration.
    # The YAML file and the command line can omit default values defined in the Python dataclasses.
    default_cfg = OmegaConf.structured(config_cls)
    cfg = OmegaConf.merge(default_cfg, cfg)
    return cfg


def load_expr_config(argv: List[str], config_cls) -> Tuple[BaseExperimentConfig, str]:
    cfg, config_file = parse_cli_args(argv)
    cfg = to_structured_cfg(cfg, config_cls=config_cls)
    cfg = OmegaConf.to_object(cfg)
    assert isinstance(cfg, BaseExperimentConfig)
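The refactor splits configuration loading into three small steps: parse the CLI into a raw OmegaConf object, merge it over the structured defaults, then materialize the dataclass. A minimal sketch of the merge semantics, using only the OmegaConf calls shown above (the user-config content here is made up):

# Sketch: YAML/CLI values override dataclass defaults; unspecified fields keep them.
from omegaconf import OmegaConf

default_cfg = OmegaConf.structured(GRPOConfig)             # all dataclass defaults
user_cfg = OmegaConf.create({"actor": {"group_size": 8}})  # e.g. parsed from YAML
merged = OmegaConf.merge(default_cfg, user_cfg)            # type-checked merge
assert merged.actor.group_size == 8                        # override applied
assert merged.actor.eps_clip == 0.2                        # default preserved
# load_expr_config() then calls OmegaConf.to_object(merged) to get the dataclass back.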
@@ -77,7 +77,7 @@ class TrainEngine(abc.ABC):

    def train_batch(
        self,
        input_: Dict,
        input_: TensorDict,
        loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
        loss_weight_fn: Callable[[Dict], float],
    ) -> Dict[str, float]:

@@ -87,7 +87,7 @@ class TrainEngine(abc.ABC):
    @torch.no_grad()
    def eval_batch(
        self,
        input_: Dict,
        input_: TensorDict,
        loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
        loss_weight_fn: Callable[[Dict], float],
    ) -> torch.Tensor | None:

@@ -97,7 +97,7 @@ class TrainEngine(abc.ABC):
    @torch.no_grad()
    def forward(
        self,
        input_: Dict,
        input_: TensorDict,
        output_seqlens: List[List[int]] | None = None,
        post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
        aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
@@ -127,10 +127,23 @@ class InferenceEngine(abc.ABC):
        """Asynchronously submit a request to the inference engine. Exits immediately."""
        raise NotImplementedError()

    def wait(self, count: int, timeout: float) -> TensorDict:
    def wait(
        self,
        count: int,
        timeout: float | None = None,
        should_accept: Callable | None = None,
    ) -> TensorDict:
        """Wait for a specified number of requests to complete, with a timeout."""
        raise NotImplementedError()

    def pause(self):
        """Pause request submission for async rollout. Used during evaluation to prevent data over-generation."""
        raise NotImplementedError()

    def resume(self):
        """Resume request submission for async rollout."""
        raise NotImplementedError()

    def rollout(
        self, data: List[Dict[str, Any]], workflow: "RolloutWorkflow"
    ) -> TensorDict:
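The widened wait() signature lets callers filter results as they arrive. A minimal sketch of how a training loop might use the new API (the predicate and the evaluation hook are illustrative placeholders, not code from the repo):

# Sketch: accept only rollouts that pass a quality predicate; pause the
# engine around evaluation so no extra data is generated meanwhile.
for item in batch:
    engine.submit(item, workflow)
rollouts = engine.wait(
    count=len(batch),
    timeout=600.0,
    should_accept=lambda td: td["rewards"].abs().sum() > 0,  # hypothetical filter
)

engine.pause()    # stop new rollout tasks during evaluation
run_evaluation()  # placeholder for the caller's evaluation routine
engine.resume()   # continue asynchronous generation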
@@ -32,7 +32,7 @@ from arealite.utils.data import (
    pad_and_stack_tensors_along_first_dim,
    pad_mb_list,
    reorder_list,
    split_packed_tensor_dict_into_mb_list,
    split_padded_tensor_dict_into_mb_list,
    unpack_sequence,
    unsqueeze_mb_list,
)

@@ -95,7 +95,7 @@ class FSDPEngine(TrainEngine):
        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
        self.device = torch.device(int(os.environ["LOCAL_RANK"]))

        dtype = torch.bfloat16 if self.config.bf16 else torch.float16
        dtype = getattr(torch, self.config.dtype)
        self.model_config = AutoConfig.from_pretrained(
            pretrained_model_name_or_path=self.config.path,
            trust_remote_code=True,

@@ -323,11 +323,8 @@ class FSDPEngine(TrainEngine):
        if isinstance(input_, dict):
            input_ = TensorDict(input_, batch_size=[input_["input_ids"].shape[0]])
        input_ = amend_position_ids(input_)
        packed_input = pack_tensor_dict(input_)
        mb_list = split_packed_tensor_dict_into_mb_list(
            packed_input,
            self.config.mb_spec,
        )
        mb_list = split_padded_tensor_dict_into_mb_list(input_, self.config.mb_spec)
        mb_list.mbs = [pack_tensor_dict(mb) for mb in mb_list.mbs]
        mb_list = pad_mb_list(mb_list, pad_value=0.0)
        # NOTE: We unsqueeze here because huggingface transformer models require
        # packed input to be of shape [1, total_seqlen].
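The preprocessing order thus changes from "pack first, then split" to "split the padded batch first, then pack each micro-batch". A schematic of the new flow, with made-up shapes for illustration:

# Sketch: a padded batch of 4 sequences -> micro-batches -> packed per-mb tensors.
# split_padded_tensor_dict_into_mb_list groups whole sequences by token budget;
# pack_tensor_dict then flattens each group into a packed layout with cu_seqlens.
batch = TensorDict(
    {
        "input_ids": torch.zeros(4, 512, dtype=torch.long),
        "attention_mask": torch.ones(4, 512, dtype=torch.bool),
    },
    batch_size=[4],
)
mb_list = split_padded_tensor_dict_into_mb_list(batch, MicroBatchSpec(n_mbs=2))
mb_list.mbs = [pack_tensor_dict(mb) for mb in mb_list.mbs]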
@@ -90,7 +90,7 @@ class HFEngine(TrainEngine):
        )
        torch.cuda.set_device("cuda:0")

        dtype = torch.bfloat16 if self.engine_config.bf16 else torch.float16
        dtype = getattr(torch, self.config.dtype)
        self.model_config = AutoConfig.from_pretrained(
            pretrained_model_name_or_path=self.engine_config.path,
            trust_remote_code=True,
@@ -0,0 +1,324 @@
import functools
from typing import Dict, List, Optional

import torch
from tensordict import TensorDict

from arealite.api.cli_args import MicroBatchSpec, PPOActorConfig
from arealite.api.engine_api import TrainEngine
from arealite.engine.fsdp_engine import FSDPEngine
from arealite.utils.data import pack_tensor_dict, split_padded_tensor_dict_into_mb_list
from arealite.utils.functional import (
    calc_entropy,
    gather_logprobs,
    masked_normalization,
    ppo_actor_loss_fn,
)
from realhf.base import stats_tracker


def grpo_loss_fn(
    logits: torch.Tensor,
    input_data: Dict,
    temperature: float,
    eps_clip: float,
    c_clip: float | None,
    behav_imp_weight_cap: float | None,
):
    """Loss function for the actor step. All inputs should be split into
    pipeline micro-batches; returns the loss and logs training stats."""
    input_ids = input_data["input_ids"]
    cu_seqlens = input_data["cu_seqlens"]
    old_logp = input_data["logprobs"]
    advantages = input_data["advantages"]
    loss_mask = input_data["loss_mask"]
    prox_logp = input_data.get("prox_logp", None)

    logits = logits.float()
    logits /= temperature
    logprobs = gather_logprobs(logits, torch.roll(input_ids, shifts=-1))
    loss, stat = ppo_actor_loss_fn(
        logprobs=logprobs,
        old_logprobs=old_logp,
        advantages=advantages,
        eps_clip=eps_clip,
        loss_mask=loss_mask,
        c_clip=c_clip,
        proximal_logprobs=prox_logp,
        behav_imp_weight_cap=behav_imp_weight_cap,
    )

    entropy = calc_entropy(logits=logits, cu_seqlens=cu_seqlens)

    # Log training statistics
    stats_tracker.denominator(
        n_tokens=torch.ones(logits.shape[0], dtype=torch.bool, device=logits.device),
        n_valid_tokens=loss_mask.bool(),
        clipped_tokens=stat["clip_mask"],
        dual_clipped_tokens=stat["dual_clip_mask"],
    )

    stats_tracker.stat(
        importance_weight=stat["importance_weight"],
        approx_kl=stat["approx_kl"],
        new_logp=logprobs.detach(),
        old_logp=old_logp,
        entropy=entropy.float(),
        actor_loss=stat["loss"],
        clip_ratio=stat["clip_mask"].float(),
        dual_clip_ratio=stat["dual_clip_mask"].float(),
        denominator="n_valid_tokens",
    )
    if "behave_imp_weight" in stat:
        stats_tracker.denominator(unclipped_behave_tokens=stat["behave_mask"])
        stats_tracker.stat(
            behave_imp_weight=stat["behave_imp_weight"],
            behave_approx_kl=stat["behave_approx_kl"],
            denominator="unclipped_behave_tokens",
        )
    vocab_min_logits = logits.detach().min(-1).values.float()
    vocab_max_logits = logits.detach().max(-1).values.float()
    stats_tracker.stat(
        vocab_min_logits=vocab_min_logits,
        vocab_max_logits=vocab_max_logits,
        denominator="n_tokens",
    )

    clip_mask = stat["clip_mask"]
    clipped_new_logp = torch.where(clip_mask, logprobs.detach(), 0.0)
    clipped_old_logp = torch.where(clip_mask, old_logp, 0.0)
    stats_tracker.stat(
        clipped_new_logp=clipped_new_logp,
        clipped_old_logp=clipped_old_logp,
        denominator="clipped_tokens",
    )
    return loss

class PPOActor:

    def __init__(self, config: PPOActorConfig, engine: TrainEngine):
        self.config = config
        self.engine = engine

        self.reward_bias = config.reward_bias
        self.reward_scaling = config.reward_scaling
        self.reward_clip = config.reward_clip

        self.group_reward_norm = config.group_reward_norm
        self.group_adv_norm = config.group_adv_norm
        self.group_size = config.group_size

        self.kl_ctl = config.kl_ctl

        self.adv_norm = config.adv_norm
        self.discount = config.discount
        self.gae_lambda = config.gae_lambda
        self.mask_no_eos_with_zero = config.mask_no_eos_with_zero

        self.temperature = config.temperature

    @torch.no_grad()
    def compute_logp(
        self,
        data: TensorDict,
        temperature: Optional[float] = None,
    ) -> torch.Tensor | None:

        def calc_logprobs(logits, input_data):
            logits = logits.float()
            labels = torch.roll(input_data["input_ids"], shifts=-1, dims=-1)
            if temperature is not None:
                logits /= temperature
            logprobs = gather_logprobs(logits, labels)
            return logprobs

        return self.engine.forward(
            input_=data,
            post_hook=calc_logprobs,
            aggregate_fn=lambda xs: torch.cat(xs, dim=-1),
        )

    def compute_advantages(self, data: TensorDict) -> None:
        bs = data["input_ids"].shape[0]
        max_seqlen = data["input_ids"].shape[1]
        batch_indices = torch.arange(
            bs, device=data["input_ids"].device, dtype=torch.long
        )

        # Compute rewards using the reward function in the synchronous RLVR pipeline.
        reward_score = data["rewards"]
        reward_score = (reward_score + self.reward_bias) * self.reward_scaling
        reward_score = torch.clip(
            reward_score, max=self.reward_clip, min=-self.reward_clip
        )
        if self.group_reward_norm:
            for i in range(bs // self.group_size):
                s = slice(i * self.group_size, (i + 1) * self.group_size)
                r = reward_score[s]
                reward_score[s] = (r - r.mean()) / (r.std() + 1e-9)

        loss_mask = data["prompt_mask"].logical_not().float()
        loss_mask = torch.roll(loss_mask, shifts=-1)
        # Apply the mask to log probabilities.
        old_logp = data["logprobs"]
        ref_logp = data["ref_logp"]
        ref_logp *= loss_mask
        old_logp *= loss_mask

        # Compute KL-regularized rewards.
        attn_mask = data["attention_mask"]
        seqlens = attn_mask.sum(-1).long()
        seq_no_eos_mask = seqlens == attn_mask.shape[1]
        rewards = -self.kl_ctl * (old_logp - ref_logp)
        kl_rewards = rewards.clone()
        # The KL reward at the token after EOS is zero.
        rewards[batch_indices, seqlens - 1] = 0
        indices = torch.clip(seqlens - 2, min=0)
        if self.mask_no_eos_with_zero:
            rewards[batch_indices, indices] += torch.where(
                seq_no_eos_mask, 0, reward_score
            )
        else:
            rewards[batch_indices, indices] += reward_score

        # Compute GAE.
        if "values" not in data:
            values = torch.zeros_like(rewards)
        else:
            values = data["values"]
        advantages_reversed = []
        lastgaelam = 0
        for t in reversed(range(max_seqlen - 1)):
            nextvalues = values[:, t + 1]
            if t == max_seqlen - 2:
                nextvalues *= seq_no_eos_mask
            delta = rewards[:, t] + self.discount * nextvalues - values[:, t]
            lastgaelam = delta + self.discount * self.gae_lambda * lastgaelam
            advantages_reversed.append(lastgaelam)
        advantages_reversed.append(
            torch.zeros(bs, dtype=torch.float32, device=values.device)
        )
        advantages = torch.stack(advantages_reversed[::-1], dim=1)

        # Optionally perform advantage normalization.
        if self.adv_norm:
            if self.group_adv_norm:
                adv_list = []
                # Normalize advantages within each whole prompt group.
                for i in range(bs // self.group_size):
                    s = slice(i * self.group_size, (i + 1) * self.group_size)
                    adv = advantages[s]
                    m = loss_mask[s]
                    adv_list.append(masked_normalization(adv, m, all_reduce=False))
                advantages = torch.cat(adv_list, 0)
            else:
                advantages = masked_normalization(advantages, loss_mask)

        # Store results in the dict.
        data["advantages"] = advantages
        data["kl_rewards"] = kl_rewards
        data["tot_rewards"] = rewards
        data["loss_mask"] = loss_mask

    def ppo_update(self, data: TensorDict) -> List[Dict[str, float]]:
        attn_mask = data["attention_mask"]
        loss_mask = data["loss_mask"]
        reward_score = data["rewards"]
        seqlens = attn_mask.sum(-1)

        all_stats = []
        ########## Logging code starts ##########
        result_denominators = {
            "correct_n_seqs": (reward_score > 0).bool(),
            "incorrect_n_seqs": (reward_score <= 0).bool(),
        }
        global_denominators = dict(
            n_seqs=torch.ones_like(reward_score, dtype=torch.bool),
            n_tokens=torch.ones_like(loss_mask, dtype=torch.bool),
            n_valid_tokens=loss_mask.bool(),
            **result_denominators,
        )
        stats_tracker.denominator(**global_denominators)
        stats_tracker.stat(
            correct_seq_len=seqlens.float(), denominator="correct_n_seqs"
        )
        stats_tracker.stat(
            incorrect_seq_len=seqlens.float(), denominator="incorrect_n_seqs"
        )

        stats = dict(
            advantages=data["advantages"],
            kl_rewards=data["kl_rewards"],
            final_reward=data["tot_rewards"],
        )
        stats_tracker.stat(**stats, denominator="n_valid_tokens")

        prompt_lens = data["prompt_mask"].sum(-1)
        seq_stats = dict(
            no_eos_ratios=(seqlens == attn_mask.shape[-1]).float(),
            task_reward=reward_score.float(),
            prompt_len=prompt_lens.float(),
            seq_len=seqlens.float(),
        )
        stats_tracker.stat(**seq_stats, denominator="n_seqs")
        scalars = dict(
            mask_no_eos_with_zero=self.config.mask_no_eos_with_zero,
            eps_clip=self.config.eps_clip,
            use_prox_logp="prox_logp" in data,
        )
        if self.config.c_clip is not None:
            scalars["c_clip"] = self.config.c_clip
            scalars["use_dual_clip"] = 1
        else:
            scalars["use_dual_clip"] = 0
        if self.config.behav_imp_weight_cap is not None:
            scalars["behav_imp_weight_cap"] = self.config.behav_imp_weight_cap
        stats_tracker.scalar(**scalars)

        global_stats = stats_tracker.export()
        for k in global_denominators:
            keys = list(global_stats.keys())
            for k2 in keys:
                if k2.endswith(k):
                    global_stats.pop(k2)
        ########## Logging code ends ##########

        mb_inputs = split_padded_tensor_dict_into_mb_list(
            data,
            mb_spec=MicroBatchSpec(n_mbs=self.config.ppo_n_minibatches),
        )
        for mb in mb_inputs.mbs:
            train_stat = self.engine.train_batch(
                mb,
                loss_fn=functools.partial(
                    grpo_loss_fn,
                    temperature=self.temperature,
                    eps_clip=self.config.eps_clip,
                    c_clip=self.config.c_clip,
                    behav_imp_weight_cap=self.config.behav_imp_weight_cap,
                ),
                loss_weight_fn=lambda x: x["loss_mask"].count_nonzero(),
            )
            stats_tracker.scalar(**train_stat)
            all_stats.append(stats_tracker.export())
        all_stats[0].update(global_stats)
        return all_stats

class FSDPPPOActor(FSDPEngine):

    def __init__(self, config: PPOActorConfig):
        super().__init__(config)
        self.actor = PPOActor(config, self)

    @torch.no_grad()
    def compute_logp(self, *args, **kwargs) -> torch.Tensor | None:
        return self.actor.compute_logp(*args, **kwargs)

    @torch.no_grad()
    def compute_advantages(self, *args, **kwargs) -> None:
        self.actor.compute_advantages(*args, **kwargs)

    def ppo_update(self, *args, **kwargs) -> List[Dict[str, float]]:
        return self.actor.ppo_update(*args, **kwargs)
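Taken together, FSDPPPOActor exposes the three calls a GRPO training step needs. A minimal sketch of one synchronous iteration wiring them up (the rollout/reward plumbing is assumed to exist; rollout_engine, workflow, ref, and cfg are placeholder names):

# Sketch of one GRPO step under the new API (illustrative glue code).
batch = rollout_engine.rollout(prompts, workflow)        # padded TensorDict with rewards
batch["ref_logp"] = ref.compute_logp(batch)              # reference log-probs
batch["logprobs"] = actor.compute_logp(batch, temperature=cfg.temperature)
actor.compute_advantages(batch)                          # fills advantages/loss_mask in place
stats = actor.ppo_update(batch)                          # minibatched PPO updates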
@@ -1,4 +1,5 @@
import asyncio
import os
import threading
import time
import traceback

@@ -7,17 +8,20 @@ from queue import Empty, Full, Queue
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional

import aiohttp
import requests
import torch.distributed as dist
from tensordict import TensorDict

from arealite.api.cli_args import InferenceEngineConfig
from arealite.api.engine_api import InferenceEngine
from arealite.api.io_struct import (
    FinetuneSpec,
    LLMRequest,
    LLMResponse,
    RolloutStat,
    WeightUpdateMeta,
)
from arealite.utils.padding import concat_padded_tensors
from realhf.base import logging, name_resolve, names, pkg_version

if TYPE_CHECKING:
@@ -46,22 +50,48 @@ class RemoteSGLangEngine(InferenceEngine):
        # Maintain the addresses for the recent 128 requests
        self.rid_queue = []

        self.addresses = config.server_addrs
        self.addresses = os.getenv("AREAL_LLM_SERVER_ADDRS").split(",")
        if not self.addresses:
            raise RuntimeError("No configured SGLang servers.")
        for addr in self.addresses:
            self._wait_for_server(addr)

        self.server_idx = 0

        qsize = config.queue_size or config.max_concurrent_rollouts * 10
        qsize = config.queue_size or config.max_concurrent_rollouts * 16
        self.input_queue = Queue(maxsize=qsize)
        self.output_queue = Queue(maxsize=qsize)
        self.result_cache = []

        self.exiting = threading.Event()
        self.paused = threading.Event()
        self.lock = threading.Lock()

        self.rollout_stat = RolloutStat()

        self._version = 0

    def initialize(self, addr: str | None, ft_spec: Optional[Dict[str, Any]] = None):
    def _wait_for_server(self, address):
        base_url = f"http://{address}"
        tik = time.time()
        while time.time() - tik < self.config.setup_timeout:
            if self.check_health(base_url):
                return
            time.sleep(1)
        raise RuntimeError("server launch failed")

    def check_health(self, base_url):
        # Check server endpoint
        try:
            response = requests.get(
                f"{base_url}/metrics",
                timeout=30,
            )
            return response.status_code == 200
        except requests.exceptions.RequestException:
            return False

    def initialize(self, addr: str | None, ft_spec: FinetuneSpec = None):
        self.rollout_thread = threading.Thread(target=self._rollout_thread)
        self.rollout_thread.start()
@@ -119,13 +149,17 @@ class RemoteSGLangEngine(InferenceEngine):
        ofp = self.config.max_head_offpolicyness
        with self.lock:
            sample_cnt = self.rollout_stat.accepted + self.rollout_stat.running
        expected_version = sample_cnt // self.config.consumer_batch_size

        consumer_bs = self.config.consumer_batch_size
        if dist.is_initialized():
            consumer_bs //= dist.get_world_size()
        expected_version = sample_cnt // consumer_bs
        not_staled = expected_version <= ofp + version
        can_rollout &= not_staled
        if not not_staled:
            cannot_rollout_reason.append(
                f"Staled: expected version ({expected_version}) = "
                f"global sample cnt ({sample_cnt}) // batch size ({self.config.consumer_batch_size}), "
                f"global sample cnt ({sample_cnt}) // batch size ({consumer_bs}), "
                f"current latest version {version}, "
                f"offpolicyness {self.config.max_head_offpolicyness}."
            )

@@ -137,7 +171,7 @@ class RemoteSGLangEngine(InferenceEngine):
        )

        # Create new rollout task
        if can_rollout and data is not None:
        if can_rollout and data is not None and not self.paused.is_set():
            task = asyncio.create_task(
                workflow.arun_episode(self, data), name=str(rid)
            )
@@ -191,6 +225,8 @@ class RemoteSGLangEngine(InferenceEngine):
                f"running: {self.rollout_stat.running}, "
                f"accepted: {self.rollout_stat.accepted}."
            )
        except Exception:
            traceback.print_exc()
        finally:
            # Cancel remaining tasks
            for task in rollout_tasks.values():

@@ -236,8 +272,7 @@ class RemoteSGLangEngine(InferenceEngine):
            async with aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(
                    total=timeout,
                    sock_connect=30,
                    sock_read=timeout,
                    sock_connect=timeout,
                )
            ) as session:
                if method.upper() == "GET":

@@ -252,7 +287,7 @@ class RemoteSGLangEngine(InferenceEngine):
                    raise ValueError(f"Unsupported HTTP method: {method}")

                response.raise_for_status()
                return response
                return await response.json()

            except (
                aiohttp.ClientError,

@@ -324,7 +359,7 @@ class RemoteSGLangEngine(InferenceEngine):
            and len(accumulated_output_tokens) < gconfig.max_new_tokens
        ):
            # loop until the generation is complete
            response = await self.arequest_with_retry(
            result = await self.arequest_with_retry(
                endpoint="/generate",
                payload=payload,
                method="POST",

@@ -332,7 +367,6 @@ class RemoteSGLangEngine(InferenceEngine):
                timeout=self.config.request_timeout,
                target_addr=server_addr,
            )
            result = await response.json()

            # Parse response
            completions += result["text"]

@@ -400,7 +434,7 @@ class RemoteSGLangEngine(InferenceEngine):
            raise NotImplementedError(f"Unsupported weight update type: {meta.type}")

    async def aupdate_weights_from_disk(self, addr, path: str):
        response = await self.arequest_with_retry(
        res = await self.arequest_with_retry(
            endpoint="/update_weights_from_disk",
            payload=dict(model_path=str(path), allow_interrupt=True),
            method="POST",

@@ -408,7 +442,6 @@ class RemoteSGLangEngine(InferenceEngine):
            timeout=self.config.request_timeout,
            target_addr=addr,
        )
        res = await response.json()
        assert res["success"]
        if "num_paused_requests" in res:
            logger.info(

@@ -422,9 +455,15 @@ class RemoteSGLangEngine(InferenceEngine):
        except Full:
            raise RuntimeError("Input queue full. Please increase queue_size.")

    def wait(self, count: int, timeout: float, should_accept: Callable) -> TensorDict:
    def wait(
        self,
        count: int,
        timeout: float | None = None,
        should_accept: Callable | None = None,
    ) -> TensorDict:
        tik = time.perf_counter()
        accepted = len(self.result_cache)
        timeout = timeout or float(7 * 24 * 3600)
        while (
            accepted < count
            and not self.exiting.is_set()

@@ -432,7 +471,7 @@ class RemoteSGLangEngine(InferenceEngine):
        ):
            try:
                result = self.output_queue.get(timeout=ROLLOUT_POLL_WAIT_TIME)
                if should_accept(result):
                if should_accept is None or should_accept(result):
                    self.result_cache.append(result)
                    accepted += 1
                else:

@@ -450,7 +489,7 @@ class RemoteSGLangEngine(InferenceEngine):
            self.result_cache[:count],
            self.result_cache[count:],
        )
        return TensorDict.cat(results, dim=0)
        return concat_padded_tensors(results)

    def rollout(
        self, data: List[Dict[str, Any]], workflow: "RolloutWorkflow"

@@ -458,8 +497,10 @@ class RemoteSGLangEngine(InferenceEngine):
        """Submit a batch of requests to the inference engine and wait for the results."""
        for item in data:
            self.submit(item, workflow)
        return self.wait(
            count=len(data),
            timeout=self.config.request_timeout,
            should_accept=lambda x: True,
        )
        return self.wait(count=len(data))

    def pause(self):
        self.paused.set()

    def resume(self):
        self.paused.clear()
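The staleness check now divides the consumer batch size by the world size, so each rank compares against its local share of samples. A quick worked example with illustrative numbers:

# Suppose consumer_batch_size=128 across 4 training ranks -> consumer_bs=32 per rank.
# With accepted + running == 96 samples on this rank:
#   expected_version = 96 // 32 = 3
# Rollout is allowed while expected_version <= max_head_offpolicyness + current_version,
# e.g. with max_head_offpolicyness=1 the policy version must be at least 2.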
@@ -0,0 +1,310 @@
import argparse
import getpass
import os
import re
import signal as signal_module
import subprocess
import sys
import time
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union

import psutil

from arealite.api.cli_args import SGLangConfig, parse_cli_args, to_structured_cfg
from arealite.api.io_struct import AllocationMode, AllocationType
from arealite.utils.network import find_free_ports, gethostip
from realhf.base import gpu_utils, logging
from realhf.scheduler.client import (
    JobException,
    JobInfo,
    JobState,
    SchedulerClient,
    SchedulerError,
)

logger = logging.getLogger("Local Scheduler")

JOB_STATE_TO_PROCESS_STATUS = {
    JobState.NOT_FOUND: [],
    JobState.PENDING: [psutil.STATUS_PARKED],
    JobState.RUNNING: [
        psutil.STATUS_RUNNING,
        psutil.STATUS_SLEEPING,
        psutil.STATUS_DISK_SLEEP,
        psutil.STATUS_TRACING_STOP,
        psutil.STATUS_WAKING,
        psutil.STATUS_WAITING,
        psutil.STATUS_LOCKED,
        psutil.STATUS_IDLE,
    ],
    JobState.COMPLETED: [
        psutil.STATUS_DEAD,
        psutil.STATUS_STOPPED,
        psutil.STATUS_ZOMBIE,
    ],
    JobState.FAILED: [],
    JobState.CANCELLED: [],
}

PROCESS_STATUS_TO_JOB_STATE = {}
for job_state, process_statuses in JOB_STATE_TO_PROCESS_STATUS.items():
    for process_status in process_statuses:
        PROCESS_STATUS_TO_JOB_STATE[process_status] = job_state

def terminate_process_and_children(pid: int, signal: Optional[Union[str, int]] = None):
    if signal is None:
        signal = signal_module.SIGKILL
    if isinstance(signal, str):
        signal = getattr(signal_module, signal)
    try:
        parent = psutil.Process(pid)
        children = parent.children(recursive=True)
        for child in children:
            terminate_process_and_children(child.pid)
        parent.send_signal(signal)
    except psutil.NoSuchProcess:
        pass


class LocalLauncher:
    def __init__(self, experiment_name: str, trial_name: str, fileroot: str):
        self.experiment_name = experiment_name
        self.trial_name = trial_name
        self.fileroot = fileroot

        self._jobs: Dict[str, subprocess.Popen] = {}
        self._job_counter: Dict[str, int] = defaultdict(int)
        self._job_states = {}

        self._gpu_counter = 0
        self._cuda_devices: List[str] = os.environ.get(
            "CUDA_VISIBLE_DEVICES", ",".join(map(str, range(gpu_utils.gpu_count())))
        ).split(",")
        if len(self._cuda_devices) < 1:
            raise RuntimeError(
                f"Local mode can only run when there is at least one GPU. "
                f"CUDA_VISIBLE_DEVICES is currently set to {os.environ['CUDA_VISIBLE_DEVICES']}."
            )

    @property
    def run_name(self):
        return f"{self.experiment_name}_{self.trial_name}"

    def log_path_of(self, job_name: str) -> str:
        log_path = f"{self.fileroot}/logs/{getpass.getuser()}/{self.experiment_name}/{self.trial_name}"
        os.makedirs(log_path, exist_ok=True)
        return os.path.join(log_path, f"{job_name}.log")

    def __del__(self):
        self.wait()

    def submit_array(
        self,
        job_name: str,
        cmd: str | List[str],
        count: int = 1,
        gpu: int = 0,
        env_vars: Optional[Dict] = None,
    ):
        if env_vars is None:
            env_vars = {}
        if not isinstance(cmd, list):
            cmd = [cmd] * count
        offset = self._job_counter[job_name]
        for i in range(count):
            if gpu > 0:
                # Allocate GPUs in a round-robin manner
                visible_devices = []
                for _ in range(gpu):
                    available_device_id = self._gpu_counter % len(self._cuda_devices)
                    self._gpu_counter += 1
                    visible_devices.append(available_device_id)
                env_vars["CUDA_VISIBLE_DEVICES"] = ",".join(
                    str(self._cuda_devices[j]) for j in visible_devices
                )
            c = (
                " ".join(str(k) + "=" + str(v) for k, v in env_vars.items())
                + " stdbuf -oL "
                + cmd[i]
            )
            c = f"{c} | tee -a {self.log_path_of(job_name)}"
            logger.info("Starting local process with command: %s", c)
            process = subprocess.Popen(c, shell=isinstance(c, str))
            self._jobs[f"{job_name}/{offset + i}"] = process
            self._job_counter[job_name] += 1

    def submit(
        self,
        job_name: str,
        cmd: str | List[str],
        gpu: int = 0,
        env_vars: Optional[Dict] = None,
    ):
        self.submit_array(job_name=job_name, cmd=cmd, gpu=gpu, env_vars=env_vars)

    def stop(self, job_name, signal=None):
        assert any(k.startswith(job_name) for k in self._jobs)
        keys = [k for k, p in self._jobs.items() if k.startswith(job_name)]
        procs = [p for k, p in self._jobs.items() if k.startswith(job_name)]
        logger.info(
            f"Stopping local process with signal {signal if signal else 'SIGKILL'}, "
            f"pid: {[p.pid for p in procs]}"
        )
        for p in procs:
            terminate_process_and_children(p.pid, signal=signal)
        for p in procs:
            p.wait()
        for k, p in zip(keys, procs):
            self._jobs.pop(k)
            del p

    def stop_all(self, signal=None):
        for name in self._job_counter:
            self.stop(name, signal=signal)

    def find(self, job_name):
        if job_name in self._jobs:
            return JobInfo(name=job_name, state=JobState.RUNNING, host="localhost")
        else:
            return JobInfo(name=job_name, state=JobState.NOT_FOUND)

    def find_all(self, job_name_regex=".*"):
        rs = []
        for name in self._jobs:
            if re.fullmatch(job_name_regex, name):
                rs.append(self.find(name))
        return rs

    def wait(
        self,
        timeout=None,
        check_status: Tuple[JobState, ...] = (
            JobState.CANCELLED,
            JobState.FAILED,
            JobState.NOT_FOUND,
        ),
        remove_status: Tuple[JobState, ...] = (JobState.COMPLETED,),
        update=False,
    ):
        deadline = None if timeout is None else time.time() + timeout
        logger.info(
            "Waiting for %d local running processes, pids: %s",
            len(self._jobs),
            " ".join(str(job.pid) for job in self._jobs.values()),
        )
        left = set(self._jobs.keys())
        num_jobs_left = len(left)

        while len(left) > 0:
            to_remove = []
            if len(left) < num_jobs_left:
                num_jobs_left = len(left)
                logger.info(f"Waiting for {num_jobs_left} jobs.")
            if deadline is not None and time.time() > deadline:
                raise TimeoutError(
                    f"Timeout waiting for {self.run_name}: {', '.join(sorted(left))}"
                )
            # Update job states
            for job_name in list(left):
                job = self._jobs[job_name]
                pid = job.pid
                process = psutil.Process(pid)
                self._job_states[job_name] = PROCESS_STATUS_TO_JOB_STATE.get(
                    process.status(), JobState.NOT_FOUND
                )

            for job_name in list(left):
                state = self._job_states[job_name]
                if state in check_status:
                    raise JobException(
                        run_name=self.run_name,
                        worker_type=job_name.split("/")[0],
                        host="local",
                        reason=state,
                    )
                if state in remove_status:
                    logger.info(f"Job {job_name} is {state}. (Removed)")
                    left.remove(job_name)
                    to_remove.append(job_name)

            if update:
                for k in to_remove:
                    self._jobs.pop(k)
                    worker_type = k.split("/")[0]
                    assert worker_type in self._job_counter
                    self._job_counter[worker_type] -= 1
                    if self._job_counter[worker_type] <= 0:
                        self._job_counter.pop(worker_type)

            time.sleep(2)

def main_local():
    cfg, _ = parse_cli_args(sys.argv[2:])
    alloc_mode = AllocationMode.from_str(cfg.allocation_mode)

    launcher = LocalLauncher(cfg.experiment_name, cfg.trial_name, cfg.cluster.fileroot)

    server_cmd = []
    server_addrs = []
    if alloc_mode.type_ == AllocationType.DECOUPLED_SGLANG:
        base_seed = cfg.sglang.random_seed
        cfg.sglang = to_structured_cfg(cfg.sglang, SGLangConfig)
        ports = find_free_ports(alloc_mode.gen_dp_size * 2, port_range=(10000, 50000))
        host_ip = gethostip()
        host = "localhost" if not cfg.sglang.enable_metrics else host_ip
        for i in range(alloc_mode.gen_dp_size):
            cfg.sglang.random_seed = base_seed + i
            cmd = SGLangConfig.build_cmd(
                cfg.sglang,
                host=host,
                tp_size=alloc_mode.gen_tp_size,
                base_gpu_id=0,
                port=ports[i * 2],
                dist_init_addr=f"localhost:{ports[i * 2 + 1]}",
            )
            server_cmd.append(cmd)
            server_addrs.append(f"{host}:{ports[i * 2]}")
    else:
        raise NotImplementedError()

    # Launch inference servers.
    launcher.submit_array(
        job_name="llm_server",
        cmd=server_cmd,
        count=alloc_mode.gen_dp_size,
        gpu=alloc_mode.gen_pp_size * alloc_mode.gen_tp_size,
    )
    logger.info(
        f"LLM inference server launched at: AREAL_LLM_SERVER_ADDRS={','.join(server_addrs)}"
    )

    # Launch trainer entrypoint
    if not cfg.server_only:
        launcher.submit(
            job_name="trainer",
            cmd=f"torchrun --nnodes 1 --nproc-per-node {alloc_mode.train_world_size} {' '.join(sys.argv[1:])}",
            gpu=alloc_mode.train_world_size,
            env_vars=dict(AREAL_LLM_SERVER_ADDRS=",".join(server_addrs)),
        )

    try:
        launcher.wait(
            check_status=(
                JobState.CANCELLED,
                JobState.FAILED,
                JobState.NOT_FOUND,
                JobState.COMPLETED,
            ),
            remove_status=(),
        )
    except (KeyboardInterrupt, JobException, TimeoutError) as e:
        launcher.stop_all("SIGTERM")
        raise e


if __name__ == "__main__":
    main_local()
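One consequence of the launcher's round-robin GPU counter is that server and trainer jobs draw from the same device pool in submission order. A tiny worked example of the arithmetic (device count and job sizes are illustrative):

# With CUDA_VISIBLE_DEVICES=0,1,2,3, two 1-GPU servers, then a 2-GPU trainer:
#   server 0 -> counter 0 -> device 0
#   server 1 -> counter 1 -> device 1
#   trainer  -> counters 2,3 -> devices 2,3
# A fourth 1-GPU job would wrap around to device 0 again (4 % 4 == 0).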
@@ -15,62 +15,41 @@ from arealite.api.cli_args import (
    SGLangConfig,
)
from arealite.api.io_struct import LLMRequest, LLMResponse, WeightUpdateMeta
from arealite.utils import network
from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import network

EXPR_NAME = "test_sglang_engine"
TRIAL_NAME = "trial_0"
MODEL_PATH = "/storage/testing/models/Qwen__Qwen3-1.7B/"
if not os.path.exists(MODEL_PATH):
    MODEL_PATH = "Qwen/Qwen2-0.5B"
PORT = 13887
DIST_PORT = 15887
PORT, DIST_PORT = network.find_free_ports(2)
HOST = network.gethostip()


def check_server_health(base_url):
    # Check server endpoint
    try:
        response = requests.get(
            f"{base_url}/metrics",
            timeout=30,
        )
        return response.status_code == 200
    except requests.exceptions.RequestException:
        return False


@pytest.fixture(scope="module")
def sglang_server():
    from realhf.base import seeding

    seeding.set_random_seed(1, EXPR_NAME)
    cmd = SGLangConfig.build_cmd(
        sglang_config=SGLangConfig(mem_fraction_static=0.3),
        model_path=MODEL_PATH,
        sglang_config=SGLangConfig(
            skip_tokenizer_init=False, model_path=MODEL_PATH, mem_fraction_static=0.3
        ),
        host=HOST,
        port=PORT,
        tp_size=1,
        base_gpu_id=0,
        dist_init_addr=f"{HOST}:{DIST_PORT}",
        served_model_name=MODEL_PATH,
        skip_tokenizer_init=False,
    )
    # Launch process
    full_command = f"{cmd} --port {PORT}"
    full_command = full_command.replace("\\\n", " ").replace("\\", " ")
    cmd = cmd.replace("\\\n", " ").replace("\\", " ")
    process = subprocess.Popen(
        full_command.split(),
        cmd.split(),
        text=True,
        stdout=sys.stdout,
        stderr=sys.stdout,
    )
    base_url = f"http://{HOST}:{PORT}"
    tik = time.time()
    while time.time() - tik < 90:
        if check_server_health(base_url):
            break
        time.sleep(1)
    if time.time() - tik > 90:
        raise RuntimeError("server launch failed")
    yield
    process.terminate()

@@ -80,7 +59,7 @@ async def test_remote_sglang_generate(sglang_server):
    from arealite.engine.sglang_remote import RemoteSGLangEngine

    config = InferenceEngineConfig(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME)
    config.server_addrs = [f"{HOST}:{PORT}"]
    os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}"
    engine = RemoteSGLangEngine(config)
    req = LLMRequest(
        rid=str(uuid.uuid4()),

@@ -109,7 +88,7 @@ def test_remote_sglang_rollout(sglang_server, n_samples):
        max_concurrent_rollouts=2,
        consumer_batch_size=2,
    )
    config.server_addrs = [f"{HOST}:{PORT}"]
    os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}"
    engine = RemoteSGLangEngine(config)
    engine.initialize(None, None)

@@ -147,7 +126,7 @@ def test_remote_sglang_staleness_control(sglang_server, bs, ofp, n_samples):
        consumer_batch_size=bs,
        max_head_offpolicyness=ofp,
    )
    config.server_addrs = [f"{HOST}:{PORT}"]
    os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}"
    engine = RemoteSGLangEngine(config)
    engine.initialize(None, None)

@@ -220,7 +199,7 @@ def test_disk_update_weights_from_fsdp_engine(tmp_path_factory, sglang_server):
    from arealite.engine.sglang_remote import RemoteSGLangEngine

    config = InferenceEngineConfig(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME)
    config.server_addrs = [f"{HOST}:{PORT}"]
    os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}"
    inf_engine = RemoteSGLangEngine(config)
    # test update weights
    path = tmp_path_factory.mktemp("upload_weights_from_disk")
@@ -8,7 +8,7 @@ from arealite.utils.data import (
    pad_and_stack_tensors_along_first_dim,
    pad_sequences_to_tensors,
    reorder_list,
    split_packed_tensor_dict_into_mb_list,
    split_padded_tensor_dict_into_mb_list,
    unpack_sequence,
)

@@ -45,7 +45,8 @@ def test_micro_batch_split(mock_padded_data, n_mbs, max_tokens_per_mb):
    packed_data = pack_tensor_dict(mock_padded_data)
    original_lens = packed_data["cu_seqlens"][1:] - packed_data["cu_seqlens"][:-1]
    assert torch.allclose(original_lens, mock_padded_data["attention_mask"].sum(1))
    split_result = split_packed_tensor_dict_into_mb_list(packed_data, mb_spec)
    split_result = split_padded_tensor_dict_into_mb_list(mock_padded_data, mb_spec)
    split_result.mbs = [pack_tensor_dict(mb) for mb in split_result.mbs]
    reordered_lens = [original_lens[i] for i in split_result.forward_indices]

    # assert microbatch split result does not violate requirements
|
@ -290,13 +290,13 @@ class MicroBatchList:
|
|||
DEFAULT_MAX_TOKENS_PER_MB = int(1e12)
|
||||
|
||||
|
||||
def split_packed_tensor_dict_into_mb_list(
|
||||
def split_padded_tensor_dict_into_mb_list(
|
||||
data: TensorDict, mb_spec: MicroBatchSpec, group: Optional[dist.ProcessGroup] = None
|
||||
) -> MicroBatchList:
|
||||
"""Split a packed tensordict into micro-batches based on the cumulative sequence lengths.
|
||||
"""Split a padded tensordict into micro-batches based on the attention mask.
|
||||
|
||||
Args:
|
||||
data (TensorDict): Dictionary containing packed tensors with "cu_seqlens" key.
|
||||
data (TensorDict): Dictionary containing padded tensors.
|
||||
mb_spec (MicroBatchSpec): Specification for micro-batch splitting.
|
||||
group (Optional[dist.ProcessGroup]): Process group for distributed synchronization.
|
||||
|
||||
|
@ -304,24 +304,21 @@ def split_packed_tensor_dict_into_mb_list(
|
|||
MicroBatchList: A structure containing the split micro-batches and metadata.
|
||||
"""
|
||||
assert (
|
||||
"cu_seqlens" in data
|
||||
), "Input data must be packed and contain 'cu_seqlens' key."
|
||||
"attention_mask" in data
|
||||
), "Input data must be padded and contain 'attention_mask' key."
|
||||
if mb_spec.max_tokens_per_mb is None:
|
||||
mb_spec = MicroBatchSpec.new(
|
||||
mb_spec, max_tokens_per_mb=DEFAULT_MAX_TOKENS_PER_MB
|
||||
)
|
||||
cu_seqlens = data["cu_seqlens"]
|
||||
bs = cu_seqlens.shape[0] - 1
|
||||
total_lens = int(cu_seqlens[-1])
|
||||
input_lens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy()
|
||||
bs = data["attention_mask"].shape[0]
|
||||
max_seqlen = data["attention_mask"].shape[1]
|
||||
input_lens = data["attention_mask"].sum(1).cpu().numpy()
|
||||
|
||||
# check tensor shape, split only 1d tensors with length "total_lens"
|
||||
to_split = {}
|
||||
not_to_split = {}
|
||||
for key, value in data.items():
|
||||
if key == "cu_seqlens" or key == "max_seqlen":
|
||||
continue
|
||||
if not torch.is_tensor(value) or value.numel() != total_lens:
|
||||
if not torch.is_tensor(value) or value.numel() != bs * max_seqlen:
|
||||
not_to_split[key] = value
|
||||
else:
|
||||
to_split[key] = value
|
||||
|
@ -331,6 +328,7 @@ def split_packed_tensor_dict_into_mb_list(
|
|||
splitted_lens = [
|
||||
[input_lens[i] for i in group_index] for group_index in group_indices
|
||||
]
|
||||
group_n_seqs = [len(x) for x in splitted_lens]
|
||||
group_lens = [sum(x) for x in splitted_lens]
|
||||
|
||||
forward_indices = datapack.flat2d(group_indices)
|
||||
|
@ -340,12 +338,16 @@ def split_packed_tensor_dict_into_mb_list(
|
|||
def _split(tensor):
|
||||
"""Split and pad a tensor based on forward indices and lens."""
|
||||
# Unpack the sequence
|
||||
unpacked = unpack_sequence(tensor, cu_seqlens=cu_seqlens)
|
||||
unpacked = [tensor[i] for i in range(bs)]
|
||||
# Reorder according to forward indices
|
||||
reordered = reorder_list(unpacked, forward_indices)
|
||||
reordered = torch.cat(reordered)
|
||||
reordered = torch.stack(reordered)
|
||||
# Unpack again according to split lens
|
||||
splitted = unpack_sequence(reordered, lens=group_lens)
|
||||
splitted = []
|
||||
offset = 0
|
||||
for _n_seqs in group_n_seqs:
|
||||
splitted.append(reordered[offset : offset + _n_seqs])
|
||||
offset += _n_seqs
|
||||
return splitted
|
||||
|
||||
to_split = dict_map(to_split, lambda x: _split(x))
|
||||
|
@ -355,16 +357,7 @@ def split_packed_tensor_dict_into_mb_list(
|
|||
# organize splitted micro batches
|
||||
assert len(mbs) == len(splitted_lens), (len(mbs), len(splitted_lens))
|
||||
for i, (mb, lens) in enumerate(zip(mbs, splitted_lens)):
|
||||
max_seqlen = max(lens)
|
||||
lens = torch.tensor(lens, device="cuda")
|
||||
batch_cu_seqlens = torch.nn.functional.pad(
|
||||
lens.cumsum(0, dtype=torch.int), (1, 0)
|
||||
)
|
||||
results.append(
|
||||
TensorDict(
|
||||
**mb, **not_to_split, max_seqlen=max_seqlen, cu_seqlens=batch_cu_seqlens
|
||||
)
|
||||
)
|
||||
results.append(TensorDict(**mb, **not_to_split))
|
||||
return MicroBatchList(
|
||||
data=data,
|
||||
mbs=results,
|
||||
|
@ -433,7 +426,7 @@ def pad_mb_list(
|
|||
# NOTE: GPU page size is 2MB
|
||||
# Take hidden size 4096 with bf16 dtype as an example,
|
||||
# the batch size of a page is 256
|
||||
pad_to_length = (l + 255) // 256 * 256
|
||||
pad_to_length = (int(l) + 255) // 256 * 256
|
||||
padded_mb, pad_len = pad_packed_tensor_dict(
|
||||
mb, pad_to_length, pad_value=pad_value
|
||||
)
|
||||
|
|
|
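The rewritten splitter works directly on padded [batch, max_seqlen] tensors and defers packing to the caller. A small illustration of what it now consumes (lengths are made up):

# Three padded sequences of true lengths 5, 3, 8 (max_seqlen = 8):
#   attention_mask.sum(1) -> input_lens = [5, 3, 8]
# These lengths drive the token-budget grouping, while each resulting
# micro-batch stays a padded TensorDict until pack_tensor_dict() is
# applied to it downstream.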
@@ -4,10 +4,6 @@ from arealite.api.cli_args import EvaluatorConfig
from arealite.api.io_struct import FinetuneSpec
from realhf.base import timeutil

if TYPE_CHECKING:
    from tensordict import TensorDict
    from torchdata.stateful_dataloader import StatefulDataLoader


class Evaluator:

@@ -22,8 +18,7 @@ class Evaluator:

    def evaluate(
        self,
        valid_dataloader: "StatefulDataLoader",
        evaluate_fn: Callable[["TensorDict"], Any],
        evaluate_fn: Callable,
        epoch: int,
        step: int,
        global_step: int,

@@ -32,5 +27,4 @@ class Evaluator:
            epochs=int(step == self.ft_sepc.steps_per_epoch - 1), steps=1
        ):
            return
        for data in valid_dataloader:
            evaluate_fn(data)
        evaluate_fn()
@ -6,3 +6,115 @@ def gather_logprobs(logits: torch.Tensor, labels: torch.Tensor):
|
|||
log_probs = torch.nn.functional.log_softmax(logits.float(), dim=-1)
|
||||
log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
|
||||
return log_probs_labels
|
||||
|
||||
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
@torch.compile
|
||||
@torch.no_grad()
|
||||
def calc_entropy(logits, cu_seqlens):
|
||||
probs = torch.nn.functional.softmax(logits.detach().float(), dim=-1)
|
||||
entropy = -torch.sum(probs * torch.log(probs + 1e-7), dim=-1)
|
||||
return entropy
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def masked_normalization(
|
||||
x: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
dim=None,
|
||||
unbiased=False,
|
||||
eps=1e-5,
|
||||
high_precision=True,
|
||||
all_reduce=True,
|
||||
reduce_group=None,
|
||||
):
|
||||
dtype = torch.float64 if high_precision else torch.float32
|
||||
x = x.to(dtype)
|
||||
if dim is None:
|
||||
dim = tuple(range(len(x.shape)))
|
||||
if mask is None:
|
||||
factor = torch.tensor(
|
||||
np.prod([x.shape[d] for d in dim]), dtype=dtype, device=x.device
|
||||
)
|
||||
else:
|
||||
mask = mask.to(dtype)
|
||||
x = x * mask
|
||||
factor = mask.sum(dim, keepdim=True)
|
||||
x_sum = x.sum(dim=dim, keepdim=True)
|
||||
x_sum_sq = x.square().sum(dim=dim, keepdim=True)
|
||||
if dist.is_initialized() and all_reduce:
|
||||
dist.all_reduce(factor, op=dist.ReduceOp.SUM, group=reduce_group)
|
||||
dist.all_reduce(x_sum, op=dist.ReduceOp.SUM, group=reduce_group)
|
||||
dist.all_reduce(
|
||||
x_sum_sq,
|
||||
op=dist.ReduceOp.SUM,
|
||||
group=reduce_group,
|
||||
)
|
||||
mean = x_sum / factor
|
||||
meansq = x_sum_sq / factor
|
||||
var = meansq - mean**2
|
||||
if unbiased:
|
||||
var *= factor / (factor - 1)
|
||||
return ((x - mean) / (var.sqrt() + eps)).float()

def ppo_actor_loss_fn(
    logprobs: torch.Tensor,
    old_logprobs: torch.Tensor,
    advantages: torch.Tensor,
    eps_clip: float,
    loss_mask: torch.Tensor,
    c_clip: Optional[float] = None,
    proximal_logprobs: Optional[torch.Tensor] = None,
    behav_imp_weight_cap: Optional[float] = None,
) -> Tuple[torch.Tensor, Dict]:
    denorm_logprobs = (
        proximal_logprobs if proximal_logprobs is not None else old_logprobs
    )
    loss_mask_count = loss_mask.count_nonzero() or 1
    ratio = torch.where(loss_mask, torch.exp(logprobs - denorm_logprobs), 0)
    clipped_ratio = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip)
    pg_loss1 = -advantages * ratio
    pg_loss2 = -advantages * clipped_ratio
    clip_mask = pg_loss1.detach() < pg_loss2.detach()
    pg_loss = torch.max(pg_loss1, pg_loss2)
    if c_clip is not None:
        assert c_clip > 1.0, c_clip
        pg_loss3 = torch.sign(advantages) * c_clip * advantages
        dual_clip_mask = pg_loss3.detach() < pg_loss.detach()
        pg_loss = torch.min(pg_loss, pg_loss3)
    else:
        dual_clip_mask = torch.zeros_like(clip_mask)
    if proximal_logprobs is not None:
        behav_kl = proximal_logprobs - old_logprobs
        behav_imp_weight = behav_kl.exp()
        behav_mask = (
            (behav_imp_weight <= behav_imp_weight_cap).logical_and(loss_mask)
            if behav_imp_weight_cap is not None
            else loss_mask
        )
        behav_kl = torch.where(behav_mask, behav_kl, 0.0)
        behav_imp_weight = torch.where(behav_mask, behav_imp_weight, 0.0)
        pg_loss = pg_loss * behav_imp_weight
    logging_loss = pg_loss.detach()
    pg_loss = torch.where(loss_mask, pg_loss, 0).sum() / loss_mask_count
    clip_mask.logical_and_(loss_mask)
    dual_clip_mask.logical_and_(loss_mask)
    stat = dict(
        loss=logging_loss,
        importance_weight=ratio.detach(),
        approx_kl=(logprobs - denorm_logprobs).detach(),
        clip_mask=clip_mask,
        dual_clip_mask=dual_clip_mask,
    )
    if proximal_logprobs is not None:
        stat["behave_imp_weight"] = behav_imp_weight
        stat["behave_approx_kl"] = behav_kl
        stat["behave_mask"] = behav_mask
    return pg_loss, stat
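A smoke test with made-up tensors, assuming the function above is importable. Per-token shapes are `[T]` with one masked-out position; passing `proximal_logprobs` (and optionally `behav_imp_weight_cap`) would switch on the decoupled-loss path used for asynchronous RL.

import torch

T = 4
logprobs = torch.randn(T, requires_grad=True)        # current policy
old_logprobs = logprobs.detach() + 0.1 * torch.randn(T)
advantages = torch.randn(T)
loss_mask = torch.tensor([True, True, True, False])  # last token ignored

loss, stat = ppo_actor_loss_fn(
    logprobs, old_logprobs, advantages, eps_clip=0.2, loss_mask=loss_mask
)
loss.backward()                 # scalar loss, averaged over the 3 valid tokens
print(stat["clip_mask"].sum())  # how many valid tokens hit the clip range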

@@ -0,0 +1,100 @@
import random
import socket
from typing import List, Set


def gethostname():
    return socket.gethostname()


def gethostip():
    return socket.gethostbyname(socket.gethostname())


def find_free_ports(
    count: int, port_range: tuple = (1024, 65535), exclude_ports: Set[int] | None = None
) -> List[int]:
    """
    Find multiple free ports within a specified range.

    Args:
        count: Number of free ports to find
        port_range: Tuple of (min_port, max_port) to search within
        exclude_ports: Set of ports to exclude from search

    Returns:
        List of free port numbers

    Raises:
        ValueError: If unable to find requested number of free ports
    """
    if exclude_ports is None:
        exclude_ports = set()

    min_port, max_port = port_range
    free_ports = []
    attempted_ports = set()

    # Calculate available port range
    available_range = max_port - min_port + 1 - len(exclude_ports)

    if count > available_range:
        raise ValueError(
            f"Cannot find {count} ports in range {port_range}. "
            f"Only {available_range} ports available."
        )

    max_attempts = count * 10  # Reasonable limit to avoid infinite loops
    attempts = 0

    while len(free_ports) < count and attempts < max_attempts:
        # Generate random port within range
        port = random.randint(min_port, max_port)

        # Skip if port already attempted or excluded
        if port in attempted_ports or port in exclude_ports:
            attempts += 1
            continue

        attempted_ports.add(port)

        if is_port_free(port):
            free_ports.append(port)

        attempts += 1

    if len(free_ports) < count:
        raise ValueError(
            f"Could only find {len(free_ports)} free ports "
            f"out of {count} requested after {max_attempts} attempts"
        )

    return sorted(free_ports)


def is_port_free(port: int) -> bool:
    """
    Check if a port is free by attempting to bind to it.

    Args:
        port: Port number to check

    Returns:
        True if port is free, False otherwise
    """
    # Check TCP; the context manager closes the socket even if bind fails.
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.bind(("", port))
    except OSError:
        return False

    # Check UDP
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            sock.bind(("", port))
        return True
    except OSError:
        return False
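Usage is straightforward (illustrative values; assumes the two helpers above are importable):

# Reserve three ports in a private range, skipping one known-busy port.
ports = find_free_ports(3, port_range=(20000, 30000), exclude_ports={22000})
print(ports)                   # e.g. [20417, 24091, 28854] -- sorted, all free
print(is_port_free(ports[0]))  # True until something binds it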

@@ -1,7 +1,7 @@
import getpass
import os
import time
from typing import Dict
from typing import Dict, List

import torch.distributed as dist
import wandb

@@ -21,6 +21,8 @@ class StatsLogger:
        self.ft_spec = ft_spec
        self.init()

        self._last_commit_step = 0

    def init(self):
        if dist.is_initialized() and dist.get_rank() != 0:
            return

@@ -61,7 +63,7 @@ class StatsLogger:
        if self.summary_writer is not None:
            self.summary_writer.close()

    def commit(self, epoch: int, step: int, global_step: int, data: Dict):
    def commit(self, epoch: int, step: int, global_step: int, data: Dict | List[Dict]):
        if dist.is_initialized() and dist.get_rank() != 0:
            return
        self.info(

@@ -69,12 +71,17 @@ class StatsLogger:
            f"Step {step+1}/{self.ft_spec.steps_per_epoch} "
            f"Train step {global_step + 1}/{self.ft_spec.total_train_steps} done."
        )
        self.info("Stats:")
        self.print_stats(data)
        wandb.log(data, step=global_step)
        if self.summary_writer is not None:
            for key, val in data.items():
                self.summary_writer.add_scalar(f"{key}", val, global_step)
        if isinstance(data, Dict):
            data = [data]
        log_step = max(global_step, self._last_commit_step)
        for i, item in enumerate(data):
            self.info(f"Stats ({i+1}/{len(data)}):")
            self.print_stats(item)
            wandb.log(item, step=log_step + i)
            if self.summary_writer is not None:
                for key, val in item.items():
                    self.summary_writer.add_scalar(f"{key}", val, log_step + i)
        self._last_commit_step = log_step + len(data) - 1

    def print_stats(self, stats: Dict[str, float]):
        self.info("\n" + tabulate_stats(stats))
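The change above lets `commit` accept either one stats dict or a list of them; list entries are logged at consecutive wandb steps, and `_last_commit_step` keeps the step counter monotonic across commits. A stripped-down sketch of just that bookkeeping (no wandb/TensorBoard; names are illustrative):

last_commit_step = 0

def next_log_steps(global_step: int, n_items: int) -> list:
    """Return the wandb steps a list of n_items stat dicts would use."""
    global last_commit_step
    log_step = max(global_step, last_commit_step)
    steps = [log_step + i for i in range(n_items)]
    last_commit_step = log_step + n_items - 1
    return steps

print(next_log_steps(0, 3))  # [0, 1, 2]
print(next_log_steps(1, 1))  # [2] -- never falls behind the last used step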

@@ -0,0 +1,112 @@
experiment_name: gsm8k-grpo
trial_name: trial0
allocation_mode: sglang.d4p1t1+d4p1t1
n_nodes: 1
n_gpus_per_node: 8
cluster:
  fileroot: /tmp/arealite/experiments
  name_resolve:
    type: nfs
    nfs_record_root: /tmp/areal/name_resolve
seed: 1
total_train_epochs: 1
tokenizer_path: ${actor.path}

rollout:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  max_concurrent_rollouts: null
  queue_size: null
  consumer_batch_size: ${train_dataset.batch_size}
  max_head_offpolicyness: 0

gconfig:
  min_new_tokens: 0
  max_new_tokens: 1024
  greedy: false
  temperature: 1.0

actor:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  path: /storage/openpsi/models/Qwen__Qwen3-1.7B/
  init_from_scratch: false
  gradient_checkpointing: false
  dtype: float16
  mb_spec:
    max_tokens_per_mb: 10240
  optimizer:
    type: adam
    lr: 2e-5
    weight_decay: 0.05
    beta1: 0.9
    beta2: 0.95
    eps: 1e-5
    lr_scheduler_type: cosine
    gradient_clipping: 1.0
  backend: fsdp

ref:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  path: ${actor.path}
  init_from_scratch: false
  dtype: ${actor.dtype}
  mb_spec:
    max_tokens_per_mb: 10240
  optimizer: null
  backend: fsdp

# SGLang
server_only: false
sglang:
  model_path: ${actor.path}
  random_seed: ${seed}
  skip_tokenizer_init: false
  dtype: ${actor.dtype}
  max_running_requests: null
  context_length: 32768
  mem_fraction_static: 0.9

# datasets
train_dataset:
  batch_size: 128
  shuffle: true
  pin_memory: true

valid_dataset:
  batch_size: 128
  shuffle: true
  pin_memory: true

# Utilities
saver:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  freq_epochs: 1
  freq_steps: null
  freq_secs: null

checkpointer:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  freq_epochs: 1
  freq_steps: null
  freq_secs: 3600

evaluator:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  freq_epochs: null
  freq_steps: 1
  freq_secs: null

stats_logger:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  wandb:
    mode: disabled
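One detail worth noting in this config: `allocation_mode: sglang.d4p1t1+d4p1t1` with `n_gpus_per_node: 8` reads naturally as a 4-GPU SGLang rollout group plus a 4-GPU training group, assuming the `d{D}p{P}t{T}` notation denotes data/pipeline/tensor parallel degrees. A small illustrative check of that arithmetic:

import re

alloc = "sglang.d4p1t1+d4p1t1"
# Assumed reading: each dDpPtT group uses D*P*T GPUs.
groups = re.findall(r"d(\d+)p(\d+)t(\d+)", alloc)
gpus = [int(d) * int(p) * int(t) for d, p, t in groups]
print(gpus, sum(gpus))  # [4, 4] 8 -- matches n_gpus_per_node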

@@ -16,7 +16,7 @@ model:
  path: /storage/openpsi/models/Qwen__Qwen3-1.7B/
  init_from_scratch: false
  gradient_checkpointing: false
  bf16: true
  dtype: bfloat16
  mb_spec:
    max_tokens_per_mb: 4096
  optimizer:

@@ -34,13 +34,11 @@ train_dataset:
  batch_size: 128
  shuffle: true
  pin_memory: true
  num_workers: 4

valid_dataset:
  batch_size: 128
  shuffle: true
  pin_memory: true
  num_workers: 4

# Utilities
saver:

@@ -0,0 +1,201 @@
import os
import re
import sys

import torch
from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node
from torchdata.stateful_dataloader import StatefulDataLoader

from arealite.api.cli_args import GRPOConfig, load_expr_config
from arealite.api.io_struct import FinetuneSpec
from arealite.engine.ppo.actor import FSDPPPOActor
from arealite.engine.sglang_remote import RemoteSGLangEngine
from arealite.utils.evaluator import Evaluator
from arealite.utils.saver import Saver
from arealite.utils.stats_logger import StatsLogger
from arealite.workflow.rlvr import RLVRWorkflow
from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import stats_tracker


def process_gsm8k_rl_dataset(dataset: Dataset):
    def process(sample):
        messages = [{"role": "user", "content": sample["question"]}]
        return {"messages": messages, "method": "strict"}

    dataset = dataset.map(process).remove_columns(["question"])
    return dataset


def get_gsm8k_dataset(split, rank, world_size):
    dataset = load_dataset(path="openai/gsm8k", name="main", split=split)
    dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
    return process_gsm8k_rl_dataset(dataset)


# Adapted from verl.
def extract_solution(solution_str, method="strict"):
    assert method in ["strict", "flexible"]

    final_answer = None
    if method == "strict":
        # this also tests the formatting of the model
        solutions = re.findall("#### (\\-?[0-9\\.\\,]+)", solution_str)
        if len(solutions) == 0:
            final_answer = None
        else:
            # take the last solution
            final_answer = solutions[-1].replace(",", "").replace("$", "")
    elif method == "flexible":
        answer = re.findall("(\\-?[0-9\\.\\,]+)", solution_str)
        final_answer = None
        if len(answer) == 0:
            # no reward if there is no answer
            pass
        else:
            invalid_str = ["", "."]
            # find the last number that is not '.'
            for final_answer in reversed(answer):
                if final_answer not in invalid_str:
                    break
    return final_answer


def gsm8k_reward_fn(
    prompt, completions, prompt_ids, completion_ids, answer, method, **kwargs
):
    sol = extract_solution(solution_str=completions, method=method)
    if sol is None:
        return 0
    return int(sol == answer)
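A quick check of the reward path, using the functions defined above (the strings are made up): GSM8K references end with `#### <number>`, and strict mode only rewards completions that reproduce that format.

completion = "She sold 5 + 3 = 8 muffins.\n#### 8"
print(extract_solution(completion, method="strict"))  # "8"
print(gsm8k_reward_fn(
    prompt="", completions=completion, prompt_ids=None,
    completion_ids=None, answer="8", method="strict",
))                                                    # 1
print(gsm8k_reward_fn(
    prompt="", completions="The answer is 8.", prompt_ids=None,
    completion_ids=None, answer="8", method="strict",
))                                                    # 0 -- missing '#### '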

def main_grpo():
    config, _ = load_expr_config(sys.argv[1:], GRPOConfig)
    config: GRPOConfig

    rank = int(os.getenv("RANK"))
    world_size = int(os.getenv("WORLD_SIZE"))
    tokenizer = load_hf_tokenizer(config.tokenizer_path)

    # Create dataset and dataloaders
    train_dataloader = StatefulDataLoader(
        get_gsm8k_dataset("train", rank, world_size),
        batch_size=config.train_dataset.batch_size // world_size,
        shuffle=config.train_dataset.shuffle,
        num_workers=config.train_dataset.num_workers,
        collate_fn=lambda x: x,
        drop_last=config.train_dataset.drop_last,
    )
    valid_dataloader = StatefulDataLoader(
        get_gsm8k_dataset("test", rank, world_size),
        batch_size=config.valid_dataset.batch_size // world_size,
        shuffle=config.valid_dataset.shuffle,
        num_workers=config.valid_dataset.num_workers,
        collate_fn=lambda x: x,
        drop_last=config.valid_dataset.drop_last,
    )
    ft_spec = FinetuneSpec(
        total_train_epochs=config.total_train_epochs,
        dataset_size=len(train_dataloader) * config.train_dataset.batch_size,
        train_batch_size=config.train_dataset.batch_size,
    )

    # Initialize inference engine
    rollout = RemoteSGLangEngine(config.rollout)
    rollout.initialize(None, ft_spec)
    eval_rollout = RemoteSGLangEngine(config.rollout)
    eval_rollout.initialize(None, ft_spec)

    # Initialize train engine
    actor = FSDPPPOActor(config=config.actor)
    actor.initialize(None, ft_spec)
    ref = None
    if config.actor.kl_ctl > 0 and config.ref is not None:
        ref = FSDPPPOActor(config=config.ref)
        ref.initialize(None, ft_spec)

    # Create rollout workflow
    if tokenizer.pad_token_id not in config.gconfig.stop_token_ids:
        config.gconfig.stop_token_ids.append(tokenizer.pad_token_id)
    if tokenizer.eos_token_id not in config.gconfig.stop_token_ids:
        config.gconfig.stop_token_ids.append(tokenizer.eos_token_id)
    workflow = RLVRWorkflow(
        reward_fn=gsm8k_reward_fn, gconfig=config.gconfig, tokenizer=tokenizer
    )

    # Run training.
    saver = Saver(config.saver, ft_spec, for_recover=False)
    logger = StatsLogger(config.stats_logger, ft_spec)
    evaluator = Evaluator(config.evaluator, ft_spec)

    total_epochs = config.total_train_epochs
    steps_per_epoch = len(train_dataloader)

    logger.info(f"total_epochs={total_epochs} step_per_epoch={steps_per_epoch}")
    global_step = 0
    for epoch in range(total_epochs):
        for step, data in enumerate(train_dataloader):
            with stats_tracker.record_timing("rollout"):
                batch = rollout.rollout(data, workflow=workflow)

            batch = batch.to(actor.device)

            if ref is not None:
                with stats_tracker.record_timing("ref_logp"):
                    batch["ref_logp"] = ref.compute_logp(batch)

            with stats_tracker.record_timing("compute_advantage"):
                actor.compute_advantages(batch)

            with (
                stats_tracker.record_timing("train_step"),
                stats_tracker.scope("grpo_actor"),
            ):
                stats = actor.ppo_update(batch)
                actor.step_lr_scheduler()

            with stats_tracker.record_timing("save"):
                saver.save(actor, epoch, step, global_step)

            with stats_tracker.record_timing("eval"):

                def evaluate_fn():
                    rollout.pause()
                    cnt = 0
                    for data in valid_dataloader:
                        for item in data:
                            eval_rollout.submit(item, workflow)
                            cnt += 1
                    batch = eval_rollout.wait(cnt, timeout=None)
                    rewards = batch["rewards"]
                    with stats_tracker.scope("grpo-eval"):
                        stats_tracker.denominator(
                            n_seqs=torch.ones(
                                rewards.shape[0],
                                device=rewards.device,
                                dtype=rewards.dtype,
                            )
                        )
                        stats_tracker.stat(task_reward=rewards, denominator="n_seqs")
                    rollout.resume()

                evaluator.evaluate(
                    evaluate_fn,
                    epoch,
                    step,
                    global_step,
                )

            logger.commit(epoch, step, global_step, stats)
            global_step += 1

    actor.destroy()
    if ref is not None:
        ref.destroy()
    logger.close()


if __name__ == "__main__":
    main_grpo()

@@ -95,12 +95,16 @@ def main_sft():
            with stats_tracker.record_timing("save"):
                saver.save(engine, epoch, step, global_step)

            with stats_tracker.record_timing("eval"), stats_tracker.scope("sft-eval"):
            with stats_tracker.record_timing("eval"):
                # No need to log anything. Logging will be handled outside
                # via stats_tracker.export().
                def evaluate_fn():
                    with stats_tracker.scope("sft-eval"):
                        for data in valid_dataloader:
                            engine.evaluate_lm(data)

                evaluator.evaluate(
                    valid_dataloader,
                    engine.evaluate_lm,
                    evaluate_fn,
                    epoch,
                    step,
                    global_step,