mirror of https://github.com/inclusionAI/AReaL
qwen2 grpo works
This commit is contained in: parent 0cbddb8aba, commit 888751da38
@@ -9,7 +9,6 @@ from hydra import initialize as hydra_init
from omegaconf import MISSING, OmegaConf

from arealite.utils.fs import get_user_tmp
from realhf.api.cli_args import OptimizerConfig


@dataclass

@@ -81,6 +80,64 @@ class GenerationHyperparameters:

# Train Engine Configs


@dataclass
class OptimizerConfig:
    """Configuration for model optimization during training.

    Note:
        Set type to "empty" for models that won't be trained.
    """

    type: str = field(
        default="adam",
        metadata={"help": "Optimizer type", "choices": ["adam", "empty"]},
    )
    lr: float = field(default=2e-5, metadata={"help": "Learning rate"})
    weight_decay: float = field(default=0.05, metadata={"help": "Weight decay"})
    beta1: float = field(default=0.9, metadata={"help": "Adam beta1 parameter"})
    beta2: float = field(default=0.95, metadata={"help": "Adam beta2 parameter"})
    eps: float = field(default=1e-5, metadata={"help": "Adam epsilon parameter"})
    min_lr_ratio: float = field(
        default=0.0,
        metadata={
            "help": "Minimum learning rate ratio after annealing",
        },
    )
    lr_scheduler_type: str = field(
        default="constant",
        metadata={
            "help": "Learning rate scheduler type",
            "choices": ["linear", "cosine", "constant"],
        },
    )
    warmup_steps_proportion: float = field(
        default=0.001,
        metadata={
            "help": "Proportion of training steps for warmup",
        },
    )
    offload: bool = field(
        default=False, metadata={"help": "Enable optimizer state offloading"}
    )
    initial_loss_scale: float = field(
        default=2**32, metadata={"help": "Initial loss scaling factor"}
    )
    min_loss_scale: float = field(
        default=1.0, metadata={"help": "Minimum loss scaling factor"}
    )
    loss_scale_window: float = field(
        default=5, metadata={"help": "Window size for loss scaling adjustment"}
    )
    hysteresis: int = field(
        default=2, metadata={"help": "Hysteresis (scaling factor) for loss scaling"}
    )
    gradient_clipping: float = field(
        default=1.0, metadata={"help": "Gradient clipping threshold"}
    )


@dataclass
class FSDPWrapPolicy:
    transformer_layer_cls_to_wrap: Optional[List[str]] = field(

@@ -284,7 +341,7 @@ class SGLangConfig:
        port,
        dist_init_addr: Optional[str] = None,
    ):
        from realhf.base import network, pkg_version, seeding
        from realhf.base import pkg_version
        from realhf.experiments.common.utils import asdict as conf_as_dict

        args: Dict = conf_as_dict(sglang_config)

@@ -361,6 +418,7 @@ class InferenceEngineConfig:
            "the request will not be accepted.",
        },
    )
    enable_rollout_tracing: bool = field(default=False)
    schedule_policy: str = field(
        default="round_robin",
        metadata={"help": "Request scheduling policy", "choices": ["round_robin"]},

@@ -663,8 +721,11 @@ def load_expr_config(argv: List[str], config_cls) -> Tuple[BaseExperimentConfig,
    assert isinstance(cfg, BaseExperimentConfig)

    # Setup environment
    from realhf.base import constants, name_resolve
    from realhf.base import constants, name_resolve, names

    constants.set_experiment_trial_names(cfg.experiment_name, cfg.trial_name)
    name_resolve.reconfigure(cfg.cluster.name_resolve)
    name_resolve.clear_subtree(
        names.trial_root(experiment_name=cfg.experiment_name, trial_name=cfg.trial_name)
    )
    return cfg, str(config_file)
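For reference, a minimal sketch (not part of this commit) of how the new OptimizerConfig dataclass can be filled from an OmegaConf override; it assumes the class is exported from arealite.api.cli_args as above, and the override values mirror the gsm8k_grpo.yaml changes further down.

# Illustrative only: merge YAML-style overrides into the OptimizerConfig defaults
# via OmegaConf structured configs.
from omegaconf import OmegaConf

from arealite.api.cli_args import OptimizerConfig  # assumed import path, as in this file

defaults = OmegaConf.structured(OptimizerConfig)
override = OmegaConf.create({"lr": 1e-6, "weight_decay": 0.01, "lr_scheduler_type": "constant"})
opt_cfg: OptimizerConfig = OmegaConf.to_object(OmegaConf.merge(defaults, override))
assert opt_cfg.type == "adam" and opt_cfg.lr == 1e-6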
@@ -16,7 +16,6 @@ from arealite.api.cli_args import GenerationHyperparameters
@dataclass
class LLMRequest:
    rid: str = field(default_factory=lambda: str(uuid.uuid4()))
    text: Optional[str] = None
    input_ids: List[int] = field(default_factory=list)
    gconfig: GenerationHyperparameters = field(
        default_factory=GenerationHyperparameters

@@ -28,7 +27,6 @@ class LLMRequest:
@dataclass
class LLMResponse:
    # outputs
    completions: str
    input_tokens: List[int] = field(default_factory=list)
    output_tokens: List[int] = field(default_factory=list)
    output_logprobs: List[float] = field(default_factory=list)
@@ -1,6 +1,7 @@
import gc
import os
import time
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional

import torch

@@ -101,13 +102,23 @@ class FSDPEngine(TrainEngine):
            trust_remote_code=True,
        )
        self.tokenizer = load_hf_tokenizer(self.config.path)
        tik = time.perf_counter()
        with torch.device("cuda"):
            # initialize scratch model from config
            model = AutoModelForCausalLM.from_config(
                self.model_config,
                torch_dtype=dtype,
                attn_implementation=self.config.attn_impl,
            )
            if self.config.init_from_scratch:
                # initialize scratch model from config
                model = AutoModelForCausalLM.from_config(
                    self.model_config,
                    torch_dtype=dtype,
                    attn_implementation=self.config.attn_impl,
                )
            else:
                model = AutoModelForCausalLM.from_pretrained(
                    pretrained_model_name_or_path=self.config.path,
                    trust_remote_code=True,
                    torch_dtype=dtype,
                    attn_implementation=self.config.attn_impl,
                )
        logger.info(f"Model creation and loading time: {time.perf_counter() - tik}")

        # Simple auto wrap policy
        self.mixed_precision_policy = MixedPrecisionPolicy(

@@ -129,23 +140,14 @@ class FSDPEngine(TrainEngine):
        }

        # Wrap with FSDP2
        tik = time.perf_counter()
        apply_fsdp2(model, fsdp_kwargs, self.config.fsdp.wrap_policy)
        logger.info(f"Applying FSDP2 time: {time.perf_counter() - tik}")
        self.model = model

        if not self.config.init_from_scratch:
            # Load model from an initial checkpoint path,
            # which should only be a huggingface checkpoint.
            load_meta = SaveLoadMeta(
                path=self.config.path,
                weight_format="hf",
                with_optim=False,
                tokenizer=None,
                base_model_path=self.config.path,
            )
            self.load(load_meta)

        # Set up optimizer
        if self.optimizer_config is not None:
            tik = time.perf_counter()
            assert (
                self.optimizer_config.type == "adam"
            ), "Only AdamW optimizer is supported in this engine."

@@ -189,6 +191,7 @@ class FSDPEngine(TrainEngine):
                raise ValueError(
                    f"Unknown lr scheduler type {self.optimizer_config.lr_scheduler_type}"
                )
            logger.info(f"Create optimizer time: {time.perf_counter() - tik}")

        self.initialized = True

@@ -300,7 +303,9 @@ class FSDPEngine(TrainEngine):
                self.config.trial_name,
                meta.model_version,
            )
            name_resolve.add(update_name, str(time.time_ns()), keepalive_ttl=120)
            name_resolve.add(
                update_name, str(datetime.now().timestamp()), keepalive_ttl=120
            )
        else:
            raise ValueError(f"Unknown weight update type {meta.type}")
@@ -7,94 +7,16 @@ from tensordict import TensorDict
from arealite.api.cli_args import MicroBatchSpec, PPOActorConfig
from arealite.api.engine_api import TrainEngine
from arealite.engine.fsdp_engine import FSDPEngine
from arealite.utils.data import pack_tensor_dict, split_padded_tensor_dict_into_mb_list
from arealite.utils.data import split_padded_tensor_dict_into_mb_list
from arealite.utils.functional import (
    calc_entropy,
    gather_logprobs,
    gather_logprobs_entropy,
    masked_normalization,
    ppo_actor_loss_fn,
)
from realhf.base import stats_tracker


def grpo_loss_fn(
    logits: torch.Tensor,
    input_data: Dict,
    temperature: float,
    eps_clip: float,
    c_clip: float | None,
    behav_imp_weight_cap: float | None,
):
    """Loss function for actor step, all inputs should be split into
    pipeline micro batches, returns loss and logging stats."""
    input_ids = input_data["input_ids"]
    cu_seqlens = input_data["cu_seqlens"]
    old_logp = input_data["logprobs"]
    advantages = input_data["advantages"]
    loss_mask = input_data["loss_mask"]
    prox_logp = input_data.get("prox_logp", None)

    logits = logits.float()
    logits /= temperature
    logprobs = gather_logprobs(logits, torch.roll(input_ids, shifts=-1))
    loss, stat = ppo_actor_loss_fn(
        logprobs=logprobs,
        old_logprobs=old_logp,
        advantages=advantages,
        eps_clip=eps_clip,
        loss_mask=loss_mask,
        c_clip=c_clip,
        proximal_logprobs=prox_logp,
        behav_imp_weight_cap=behav_imp_weight_cap,
    )

    entropy = calc_entropy(logits=logits, cu_seqlens=cu_seqlens)

    # Log training statistics
    stats_tracker.denominator(
        n_tokens=torch.ones(logits.shape[0], dtype=torch.bool, device=logits.device),
        n_valid_tokens=loss_mask.bool(),
        clipped_tokens=stat["clip_mask"],
        dual_clipped_tokens=stat["dual_clip_mask"],
    )

    stats_tracker.stat(
        importance_weight=stat["importance_weight"],
        approx_kl=stat["approx_kl"],
        new_logp=logprobs.detach(),
        old_logp=old_logp,
        entropy=entropy.float(),
        actor_loss=stat["loss"],
        clip_ratio=stat["clip_mask"].float(),
        dual_clip_ratio=stat["dual_clip_mask"].float(),
        denominator="n_valid_tokens",
    )
    if "behave_imp_weight" in stat:
        stats_tracker.denominator(unclipped_behave_tokens=stat["behave_mask"])
        stats_tracker.stat(
            behave_imp_weight=stat["behave_imp_weight"],
            behave_approx_kl=stat["behave_approx_kl"],
            denominator="unclipped_behave_tokens",
        )
    vocab_min_logits = logits.detach().min(-1).values.float()
    vocab_max_logits = logits.detach().max(-1).values.float()
    stats_tracker.stat(
        vocab_min_logits=vocab_min_logits,
        vocab_max_logits=vocab_max_logits,
        denominator="n_tokens",
    )

    clip_mask = stat["clip_mask"]
    clipped_new_logp = torch.where(clip_mask, logprobs.detach(), 0.0)
    clipped_old_logp = torch.where(clip_mask, old_logp, 0.0)
    stats_tracker.stat(
        clipped_new_logp=clipped_new_logp,
        clipped_old_logp=clipped_old_logp,
        denominator="clipped_tokens",
    )
    return loss


class PPOActor:

    def __init__(self, config: PPOActorConfig, engine: TrainEngine):

@@ -126,11 +48,8 @@ class PPOActor:
    ) -> torch.Tensor | None:

        def calc_logprobs(logits, input_data):
            logits = logits.float()
            labels = torch.roll(input_data["input_ids"], shifts=-1, dims=-1)
            if temperature is not None:
                logits /= temperature
            logprobs = gather_logprobs(logits, labels)
            logprobs = gather_logprobs(logits, labels, temperature or 1.0)
            return logprobs

        return self.engine.forward(

@@ -159,10 +78,13 @@ class PPOActor:
            reward_score[s] = (r - r.mean()) / (r.std() + 1e-9)

        loss_mask = data["prompt_mask"].logical_not().float()
        loss_mask = torch.roll(loss_mask, shifts=-1)
        loss_mask = torch.roll(loss_mask, shifts=-1, dims=-1)
        # Apply the mask to log probabilities.
        old_logp = data["logprobs"]
        ref_logp = data["ref_logp"]
        if not self.config.recompute_logprob:
            old_logp = torch.roll(data["logprobs"], shifts=-1, dims=-1)
        else:
            old_logp = data["logprobs"]
        ref_logp = data.get("ref_logp", torch.zeros_like(old_logp))
        ref_logp *= loss_mask
        old_logp *= loss_mask

@@ -219,6 +141,8 @@ class PPOActor:
        data["kl_rewards"] = kl_rewards
        data["tot_rewards"] = rewards
        data["loss_mask"] = loss_mask
        # because we have rolled old_logp by -1
        data["logprobs"] = old_logp

    def ppo_update(self, data: TensorDict) -> List[Dict[str, float]]:
        attn_mask = data["attention_mask"]

@@ -284,6 +208,8 @@ class PPOActor:
                global_stats.pop(k2)
        ########## Logging code ends ##########

        for key in ["rewards", "tot_rewards", "kl_rewards", "versions"]:
            data.pop(key, None)
        mb_inputs = split_padded_tensor_dict_into_mb_list(
            data,
            mb_spec=MicroBatchSpec(n_mbs=self.config.ppo_n_minibatches),

@@ -322,3 +248,79 @@ class FSDPPPOActor(FSDPEngine):

    def ppo_update(self, *args, **kwargs) -> List[Dict[str, float]]:
        return self.actor.ppo_update(*args, **kwargs)


def grpo_loss_fn(
    logits: torch.Tensor,
    input_data: Dict,
    temperature: float,
    eps_clip: float,
    c_clip: float | None,
    behav_imp_weight_cap: float | None,
):
    """Loss function for actor step, all inputs should be split into
    pipeline micro batches, returns loss and logging stats."""
    input_ids = input_data["input_ids"]
    old_logp = input_data["logprobs"]
    advantages = input_data["advantages"]
    loss_mask = input_data["loss_mask"].bool()
    prox_logp = input_data.get("prox_logp", None)

    logprobs, entropy = gather_logprobs_entropy(
        logits, torch.roll(input_ids, shifts=-1, dims=-1), temperature
    )
    entropy = entropy.detach()
    loss, stat = ppo_actor_loss_fn(
        logprobs=logprobs,
        old_logprobs=old_logp,
        advantages=advantages,
        eps_clip=eps_clip,
        loss_mask=loss_mask,
        c_clip=c_clip,
        proximal_logprobs=prox_logp,
        behav_imp_weight_cap=behav_imp_weight_cap,
    )

    # Log training statistics
    stats_tracker.denominator(
        n_tokens=torch.ones(logits.shape[0], dtype=torch.bool, device=logits.device),
        n_valid_tokens=loss_mask.bool(),
        clipped_tokens=stat["clip_mask"],
        dual_clipped_tokens=stat["dual_clip_mask"],
    )

    stats_tracker.stat(
        importance_weight=stat["importance_weight"],
        approx_kl=stat["approx_kl"],
        new_logp=logprobs.detach(),
        old_logp=old_logp,
        entropy=entropy.float(),
        actor_loss=stat["loss"],
        clip_ratio=stat["clip_mask"].float(),
        dual_clip_ratio=stat["dual_clip_mask"].float(),
        denominator="n_valid_tokens",
    )
    if "behave_imp_weight" in stat:
        stats_tracker.denominator(unclipped_behave_tokens=stat["behave_mask"])
        stats_tracker.stat(
            behave_imp_weight=stat["behave_imp_weight"],
            behave_approx_kl=stat["behave_approx_kl"],
            denominator="unclipped_behave_tokens",
        )
    vocab_min_logits = logits.detach().min(-1).values.float()
    vocab_max_logits = logits.detach().max(-1).values.float()
    stats_tracker.stat(
        vocab_min_logits=vocab_min_logits,
        vocab_max_logits=vocab_max_logits,
        denominator="n_tokens",
    )

    clip_mask = stat["clip_mask"]
    clipped_new_logp = torch.where(clip_mask, logprobs.detach(), 0.0)
    clipped_old_logp = torch.where(clip_mask, old_logp, 0.0)
    stats_tracker.stat(
        clipped_new_logp=clipped_new_logp,
        clipped_old_logp=clipped_old_logp,
        denominator="clipped_tokens",
    )
    return loss
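The reward normalization line "reward_score[s] = (r - r.mean()) / (r.std() + 1e-9)" above is the group-relative advantage at the heart of GRPO. A self-contained sketch of that computation (shapes and the helper name are assumptions, not repository code):

import torch

# Sketch: normalize rewards within each group of n_samples responses to the same prompt,
# mirroring reward_score[s] = (r - r.mean()) / (r.std() + 1e-9) above. Shapes are assumed.
def group_normalize_rewards(rewards: torch.Tensor, n_samples: int) -> torch.Tensor:
    grouped = rewards.view(-1, n_samples)              # [n_prompts, n_samples]
    mean = grouped.mean(dim=1, keepdim=True)
    std = grouped.std(dim=1, keepdim=True)
    return ((grouped - mean) / (std + 1e-9)).view(-1)  # flattened per-response advantages

adv = group_normalize_rewards(torch.tensor([1.0, 0.0, 1.0, 1.0]), n_samples=4)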
@@ -4,6 +4,7 @@ import threading
import time
import traceback
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from queue import Empty, Full, Queue
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional

@@ -34,7 +35,7 @@ if pkg_version.is_available("sglang"):
else:
    SGLANG_TOKEN_OUTPUT_IDENTIFIER = "token_ids"

ROLLOUT_POLL_WAIT_TIME = 0.4
ROLLOUT_POLL_WAIT_TIME = 0.1
RID_CACHE_SIZE = 128

@@ -53,8 +54,10 @@ class RemoteSGLangEngine(InferenceEngine):
        self.addresses = os.getenv("AREAL_LLM_SERVER_ADDRS").split(",")
        if not self.addresses:
            raise RuntimeError("No configured SGLang servers.")
        logger.info("Waiting for server ready...")
        for addr in self.addresses:
            self._wait_for_server(addr)
        logger.info("Servers are all ready!")

        self.server_idx = 0

@@ -115,7 +118,7 @@ class RemoteSGLangEngine(InferenceEngine):
            traceback.print_exc()

    async def _rollout_thread_async(self):
        data = None
        pending_data = []

        rollout_tasks: Dict[str, asyncio.Task] = {}
        rid = 0

@@ -123,12 +126,14 @@ class RemoteSGLangEngine(InferenceEngine):
        try:
            while not self.exiting.is_set():
                # Load next data from controller
                if data is None:
                while True:
                    try:
                        data, workflow = self.input_queue.get_nowait()
                        logger.info(f"Get data from puller: {data}")
                        logger.debug(f"Get data from puller: {data}")
                        pending_data.append(data)
                    except Empty:
                        logger.debug(f"No data from puller stream.")
                        break

                # Check capacity
                if dist.is_initialized():

@@ -136,59 +141,36 @@ class RemoteSGLangEngine(InferenceEngine):
                else:
                    world_size = 1

                cannot_rollout_reason = []
                capacity = max(1, self.config.max_concurrent_rollouts // world_size)
                can_rollout = len(rollout_tasks) < capacity
                if not can_rollout:
                    cannot_rollout_reason.append(
                        f"Exceeding capacity: # running tasks {len(rollout_tasks)} >= capacity {capacity}"
                    )

                max_concurrent_rollouts = max(
                    1, self.config.max_concurrent_rollouts // world_size
                )
                capacity = max_concurrent_rollouts - len(rollout_tasks)
                # Staleness control
                version = self.get_version()
                ofp = self.config.max_head_offpolicyness
                with self.lock:
                    sample_cnt = self.rollout_stat.accepted + self.rollout_stat.running

                consumer_bs = self.config.consumer_batch_size
                if dist.is_initialized():
                    consumer_bs //= dist.get_world_size()
                expected_version = sample_cnt // consumer_bs
                not_staled = expected_version <= ofp + version
                can_rollout &= not_staled
                if not not_staled:
                    cannot_rollout_reason.append(
                        f"Staled: expected version ({expected_version}) = "
                        f"global sample cnt ({sample_cnt}) // batch size ({consumer_bs}), "
                        f"current latest version {version}, "
                        f"offpolicyness {self.config.max_head_offpolicyness}."
                    )

                if not can_rollout:
                    logger.debug(
                        f"Cannot submit new rollouts. "
                        + "\n".join(cannot_rollout_reason)
                    )
                consumer_bs = max(1, self.config.consumer_batch_size // world_size)
                capacity = min(capacity, (ofp + version + 1) * consumer_bs - sample_cnt)

                # Create new rollout task
                if can_rollout and data is not None and not self.paused.is_set():
                while capacity > 0 and pending_data and not self.paused.is_set():
                    task = asyncio.create_task(
                        workflow.arun_episode(self, data), name=str(rid)
                        workflow.arun_episode(self, pending_data.pop(0)), name=str(rid)
                    )
                    rollout_tasks[str(rid)] = task

                    with self.lock:
                        self.rollout_stat.submitted += 1
                        self.rollout_stat.running += 1
                        logger.info(
                            f"Submit rollout rid {rid}. "
                            f"Submit: {self.rollout_stat.submitted}, "
                            f"running: {self.rollout_stat.running}, "
                            f"accepted: {self.rollout_stat.accepted}."
                        )

                        if self.config.enable_rollout_tracing:
                            logger.info(
                                f"Submit rollout rid {rid}. "
                                f"Submit: {self.rollout_stat.submitted}, "
                                f"running: {self.rollout_stat.running}, "
                                f"accepted: {self.rollout_stat.accepted}."
                            )
                    capacity -= 1
                    rid += 1
                    data = None

                # Wait for rollout completion
                tasks = list(rollout_tasks.values())

@@ -199,8 +181,10 @@ class RemoteSGLangEngine(InferenceEngine):
                        timeout=ROLLOUT_POLL_WAIT_TIME,
                        return_when=asyncio.FIRST_COMPLETED,
                    )
                    if not done:
                        await asyncio.sleep(1)
                else:
                    await asyncio.sleep(ROLLOUT_POLL_WAIT_TIME)
                    await asyncio.sleep(1)

                # Collect done results
                for task in done:

@@ -219,12 +203,13 @@ class RemoteSGLangEngine(InferenceEngine):

                    with self.lock:
                        self.rollout_stat.running -= 1
                        logger.info(
                            f"Finish rollout {task_rid}. "
                            f"Submit: {self.rollout_stat.submitted}, "
                            f"running: {self.rollout_stat.running}, "
                            f"accepted: {self.rollout_stat.accepted}."
                        )
                        if self.config.enable_rollout_tracing:
                            logger.info(
                                f"Finish rollout {task_rid}. "
                                f"Submit: {self.rollout_stat.submitted}, "
                                f"running: {self.rollout_stat.running}, "
                                f"accepted: {self.rollout_stat.accepted}."
                            )
        except Exception:
            traceback.print_exc()
        finally:

@@ -323,15 +308,11 @@ class RemoteSGLangEngine(InferenceEngine):

        # NOTE: rid should NOT be passed in payload
        payload = {
            "text": req.text,
            "input_ids": req.input_ids.copy(),
            "sampling_params": sample_params,
            "return_logprob": True,
            "stream": False,
        }
        if req.text:
            payload["text"] = req.text
        else:
            payload["input_ids"] = req.input_ids

        # Make request
        start_time = time.perf_counter()

@@ -369,7 +350,6 @@ class RemoteSGLangEngine(InferenceEngine):
            )

            # Parse response
            completions += result["text"]
            meta_info = result["meta_info"]
            output_tokens = [x[1] for x in meta_info["output_token_logprobs"]]
            output_logprobs = [x[0] for x in meta_info["output_token_logprobs"]]

@@ -384,12 +364,11 @@ class RemoteSGLangEngine(InferenceEngine):
            finish_reason = meta_info["finish_reason"]
            stop_reason = finish_reason["type"]

            payload["text"] += result["text"]
            payload["input_ids"] += result[SGLANG_TOKEN_OUTPUT_IDENTIFIER]

        latency = time.perf_counter() - start_time

        return LLMResponse(
            completions=completions,
            input_tokens=req.input_ids,
            output_tokens=accumulated_output_tokens,
            output_logprobs=accumulated_output_logprobs,

@@ -410,10 +389,10 @@ class RemoteSGLangEngine(InferenceEngine):
            update_name = names.update_weights_from_disk(
                self.config.experiment_name, self.config.trial_name, meta.model_version
            )
            save_timestamp = int(name_resolve.wait(update_name, timeout=120))
            load_timestamp = time.time_ns()
            save_timestamp = float(name_resolve.wait(update_name, timeout=120))
            load_timestamp = datetime.now().timestamp()
            logger.info(
                f"Begin update weights from {meta.path}, responded in {(load_timestamp - save_timestamp)/1e6:.2f} ms"
                f"Begin update weights from {meta.path}, responded in {(load_timestamp - save_timestamp):.2f}s"
            )
            try:
                jobs = [

@@ -427,7 +406,7 @@ class RemoteSGLangEngine(InferenceEngine):
            finally:
                loop.close()
            logger.info(
                f"Loading weights done in {(time.time_ns() - load_timestamp)/1e6:.2f} ms"
                f"Loading weights done in {(datetime.now().timestamp() - load_timestamp):.2f}s"
            )
            self.set_version(meta.model_version)
        else:

@@ -478,7 +457,7 @@ class RemoteSGLangEngine(InferenceEngine):
                with self.lock:
                    self.rollout_stat.accepted -= 1
            except Empty:
                time.sleep(ROLLOUT_POLL_WAIT_TIME)
                pass
            if self.exiting.is_set():
                raise RuntimeError("Rollout engine is exiting, cannot wait for results.")
            if accepted < count:
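The rewritten loop above folds staleness control into a single submission budget. A plain-Python restatement of that arithmetic with made-up numbers (illustrative only, not repository code):

# Sketch of the staleness-aware submission budget used above (illustrative values).
max_concurrent_rollouts = max(1, 256 // 8)   # config.max_concurrent_rollouts // world_size
running_tasks = 20
capacity = max_concurrent_rollouts - running_tasks

version = 3                                  # latest trainer weight version
ofp = 1                                      # config.max_head_offpolicyness
consumer_bs = max(1, 1024 // 8)              # config.consumer_batch_size // world_size
sample_cnt = 630                             # accepted + running samples so far

# Never submit more samples than the next (ofp + version + 1) consumer batches allow.
capacity = min(capacity, (ofp + version + 1) * consumer_bs - sample_cnt)
print(capacity)  # 10 here: the staleness bound, not the concurrency bound, is binding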
@@ -1,4 +1,3 @@
import argparse
import getpass
import os
import re

@@ -15,13 +14,7 @@ from arealite.api.cli_args import SGLangConfig, parse_cli_args, to_structured_cf
from arealite.api.io_struct import AllocationMode, AllocationType
from arealite.utils.network import find_free_ports, gethostip
from realhf.base import gpu_utils, logging
from realhf.scheduler.client import (
    JobException,
    JobInfo,
    JobState,
    SchedulerClient,
    SchedulerError,
)
from realhf.scheduler.client import JobException, JobInfo, JobState

logger = logging.getLogger("Local Scheduler")

@@ -286,7 +279,7 @@ def main_local():
    if not cfg.server_only:
        launcher.submit(
            job_name="trainer",
            cmd=f"torchrun --nnodes 1 --nproc-per-node {alloc_mode.train_world_size} {' '.join(sys.argv[1:])}",
            cmd=f"torchrun --nnodes 1 --nproc-per-node {alloc_mode.train_world_size} --standalone {' '.join(sys.argv[1:])}",
            gpu=alloc_mode.train_world_size,
            env_vars=dict(AREAL_LLM_SERVER_ADDRS=",".join(server_addrs)),
        )
@@ -5,7 +5,6 @@ import time
import uuid

import pytest
import requests
import torch
from tensordict import TensorDict

@@ -34,7 +33,9 @@ def sglang_server():
    seeding.set_random_seed(1, EXPR_NAME)
    cmd = SGLangConfig.build_cmd(
        sglang_config=SGLangConfig(
            skip_tokenizer_init=False, model_path=MODEL_PATH, mem_fraction_static=0.3
            skip_tokenizer_init=True,
            model_path=MODEL_PATH,
            mem_fraction_static=0.3,
        ),
        host=HOST,
        port=PORT,

@@ -59,11 +60,12 @@ async def test_remote_sglang_generate(sglang_server):
    from arealite.engine.sglang_remote import RemoteSGLangEngine

    config = InferenceEngineConfig(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME)
    tokenizer = load_hf_tokenizer(MODEL_PATH)
    os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}"
    engine = RemoteSGLangEngine(config)
    req = LLMRequest(
        rid=str(uuid.uuid4()),
        text="hello! how are you today",
        input_ids=tokenizer.encode("hello! how are you today"),
        gconfig=GenerationHyperparameters(max_new_tokens=16),
    )
    resp = await engine.agenerate(req)

@@ -74,7 +76,6 @@ async def test_remote_sglang_generate(sglang_server):
        == len(resp.output_tokens)
        == len(resp.output_versions)
    )
    assert isinstance(resp.completions, str)


@pytest.mark.parametrize("n_samples", [1, 2, 4])

@@ -101,6 +102,7 @@ def test_remote_sglang_rollout(sglang_server, n_samples):
        reward_fn=lambda **kwargs: 1.0,  # Dummy reward function
        gconfig=gconfig,
        tokenizer=tokenizer,
        enable_thinking=False,
    )

    data = {

@@ -139,6 +141,7 @@ def test_remote_sglang_staleness_control(sglang_server, bs, ofp, n_samples):
        reward_fn=lambda **kwargs: 1.0,  # Dummy reward function
        gconfig=gconfig,
        tokenizer=tokenizer,
        enable_thinking=False,
    )
    data = {
        "messages": [{"role": "user", "content": "Hello, how are you?"}],
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Callable
from typing import Callable

from arealite.api.cli_args import EvaluatorConfig
from arealite.api.io_struct import FinetuneSpec
@@ -2,12 +2,24 @@ import torch


@torch.compile
def gather_logprobs(logits: torch.Tensor, labels: torch.Tensor):
    log_probs = torch.nn.functional.log_softmax(logits.float(), dim=-1)
def gather_logprobs(
    logits: torch.Tensor, labels: torch.Tensor, temperature: float = 1.0
):
    log_probs = torch.nn.functional.log_softmax(logits.float() / temperature, dim=-1)
    log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
    return log_probs_labels


@torch.compile
def gather_logprobs_entropy(
    logits: torch.Tensor, labels: torch.Tensor, temperature: float = 1.0
):
    log_probs = torch.nn.functional.log_softmax(logits.float() / temperature, dim=-1)
    entropy = -torch.sum(log_probs.exp() * log_probs, dim=-1)
    log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
    return log_probs_labels, entropy


from typing import Dict, Optional, Tuple

import numpy as np

@@ -15,14 +27,6 @@ import torch
import torch.distributed as dist


@torch.compile
@torch.no_grad()
def calc_entropy(logits, cu_seqlens):
    probs = torch.nn.functional.softmax(logits.detach().float(), dim=-1)
    entropy = -torch.sum(probs * torch.log(probs + 1e-7), dim=-1)
    return entropy


@torch.no_grad()
def masked_normalization(
    x: torch.Tensor,
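A small self-contained check of the new temperature argument (assumes the arealite package is importable; not part of the commit): dividing the logits by the temperature before log_softmax is all that gather_logprobs does with it.

import torch
from arealite.utils.functional import gather_logprobs  # the function defined above

# Illustrative check only: the temperature argument rescales logits before log_softmax.
logits = torch.randn(4, 10)          # [tokens, vocab]
labels = torch.randint(0, 10, (4,))
manual = torch.nn.functional.log_softmax(logits / 0.7, dim=-1).gather(
    -1, labels.unsqueeze(-1)
).squeeze(-1)
torch.testing.assert_close(gather_logprobs(logits, labels, 0.7), manual)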
@@ -51,7 +51,9 @@ class Saver:
            epochs=int(step == self.ft_sepc.steps_per_epoch - 1), steps=1
        ):
            return
        path = self.get_save_checkpoint_path(epoch, step, global_step, name)
        path = Saver.get_save_checkpoint_path(
            self.config, epoch, step, global_step, name
        )
        weight_format = "hf"
        with_optim = False
        if self.for_recover:
@@ -17,19 +17,24 @@ class RLVRWorkflow(RolloutWorkflow):
        reward_fn,
        gconfig: GenerationHyperparameters,
        tokenizer: PreTrainedTokenizerFast,
        enable_thinking: bool,
    ):
        self.reward_fn = reward_fn
        self.gconfig = gconfig
        self.tokenizer = tokenizer
        self.enable_thinking = enable_thinking

    async def arun_episode(self, engine, data):
        text = self.tokenizer.apply_chat_template(
            data["messages"], tokenize=False, add_generation_prompt=True
        input_ids = self.tokenizer.apply_chat_template(
            data["messages"],
            tokenize=True,
            add_generation_prompt=True,
            enable_thinking=self.enable_thinking,
        )
        n_samples = self.gconfig.n_samples
        req = LLMRequest(
            rid=uuid.uuid4().hex,
            text=text,
            input_ids=input_ids,
            gconfig=self.gconfig.new(n_samples=1),
        )
        resps = await asyncio.gather(*[engine.agenerate(req) for _ in range(n_samples)])

@@ -42,8 +47,8 @@ class RLVRWorkflow(RolloutWorkflow):
        versions = [-1] * resp.input_len + resp.output_versions

        reward = self.reward_fn(
            prompt=req.text,
            completions=resp.completions,
            prompt=self.tokenizer.decode(input_ids),
            completions=self.tokenizer.decode(resp.output_tokens),
            prompt_ids=resp.input_tokens,
            completion_ids=resp.output_tokens,
            **data,
@@ -9,43 +9,55 @@ cluster:
    type: nfs
    nfs_record_root: /tmp/areal/name_resolve
seed: 1
total_train_epochs: 1
total_train_epochs: 10
tokenizer_path: ${actor.path}

rollout:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  max_concurrent_rollouts: null
  max_concurrent_rollouts: 256
  queue_size: null
  consumer_batch_size: ${train_dataset.batch_size}
  max_head_offpolicyness: 0

gconfig:
  n_samples: 4
  min_new_tokens: 0
  max_new_tokens: 1024
  max_new_tokens: 512
  greedy: false
  temperature: 1.0

actor:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  path: /storage/openpsi/models/Qwen__Qwen3-1.7B/
  path: /storage/openpsi/models/Qwen__Qwen2-1.5B-Instruct/
  init_from_scratch: false
  gradient_checkpointing: false
  dtype: float16
  dtype: bfloat16
  mb_spec:
    max_tokens_per_mb: 10240
  optimizer:
    type: adam
    lr: 2e-5
    weight_decay: 0.05
    lr: 1e-6
    weight_decay: 0.01
    beta1: 0.9
    beta2: 0.95
    eps: 1e-5
    lr_scheduler_type: cosine
    beta2: 0.999
    eps: 1e-8
    lr_scheduler_type: constant
    gradient_clipping: 1.0
    warmup_steps_proportion: 0.001
  backend: fsdp

  group_size: ${gconfig.n_samples}
  group_adv_norm: false
  eps_clip: 0.2
  temperature: ${gconfig.temperature}
  reward_scaling: 10.0
  reward_bias: -0.5
  kl_ctl: 0.0
  ppo_n_minibatches: 4
  recompute_logprob: true

ref:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}

@@ -62,7 +74,7 @@ server_only: false
sglang:
  model_path: ${actor.path}
  random_seed: ${seed}
  skip_tokenizer_init: false
  skip_tokenizer_init: true
  dtype: ${actor.dtype}
  max_running_requests: null
  context_length: 32768

@@ -70,12 +82,12 @@ sglang:

# datasets
train_dataset:
  batch_size: 128
  batch_size: 1024
  shuffle: true
  pin_memory: true

valid_dataset:
  batch_size: 128
  batch_size: 1024
  shuffle: true
  pin_memory: true

@@ -100,8 +112,8 @@ evaluator:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  freq_epochs: null
  freq_steps: 1
  freq_epochs: 1
  freq_steps: null
  freq_secs: null

stats_logger:
@@ -3,12 +3,13 @@ import re
import sys

import torch
import torch.distributed as dist
from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node
from torchdata.stateful_dataloader import StatefulDataLoader

from arealite.api.cli_args import GRPOConfig, load_expr_config
from arealite.api.io_struct import FinetuneSpec
from arealite.api.io_struct import FinetuneSpec, WeightUpdateMeta
from arealite.engine.ppo.actor import FSDPPPOActor
from arealite.engine.sglang_remote import RemoteSGLangEngine
from arealite.utils.evaluator import Evaluator

@@ -22,7 +23,7 @@ from realhf.base import stats_tracker
def process_gsm8k_rl_dataset(dataset: Dataset):
    def process(sample):
        messages = [{"role": "user", "content": sample["question"]}]
        return {"messages": messages, "method": "strict"}
        return {"messages": messages}

    dataset = dataset.map(process).remove_columns(["question"])
    return dataset

@@ -35,7 +36,7 @@ def get_gsm8k_dataset(split, rank, world_size):


# Adapted from verl.
def extract_solution(solution_str, method="strict"):
def extract_solution(solution_str, method="strict") -> str | None:
    assert method in ["strict", "flexible"]

    final_answer = None

@@ -62,13 +63,16 @@ def extract_solution(solution_str, method="strict"):
    return final_answer


def gsm8k_reward_fn(
    prompt, completions, prompt_ids, completion_ids, answer, method, **kwargs
):
    sol = extract_solution(solution_str=completions, method=method)
def gsm8k_reward_fn(prompt, completions, prompt_ids, completion_ids, answer, **kwargs):
    from realhf.impl.dataset.math_parser import extract_answer

    sol = extract_answer(completions, data_name="math")
    ans = extract_solution(solution_str=answer, method="strict")
    if sol is None:
        return 0
    return int(sol == answer)
    if ans is None:
        return 0
    return int(sol.strip() == ans.strip())


def main_grpo():

@@ -107,6 +111,8 @@ def main_grpo():
    rollout.initialize(None, ft_spec)
    eval_rollout = RemoteSGLangEngine(config.rollout)
    eval_rollout.initialize(None, ft_spec)
    # NOTE: set a large version such that eval does not have any offpolicyness control
    eval_rollout.set_version(int(1e12))

    # Initialize train engine
    actor = FSDPPPOActor(config=config.actor)

@@ -122,7 +128,10 @@ def main_grpo():
    if tokenizer.eos_token_id not in config.gconfig.stop_token_ids:
        config.gconfig.stop_token_ids.append(tokenizer.eos_token_id)
    workflow = RLVRWorkflow(
        reward_fn=gsm8k_reward_fn, gconfig=config.gconfig, tokenizer=tokenizer
        reward_fn=gsm8k_reward_fn,
        gconfig=config.gconfig,
        tokenizer=tokenizer,
        enable_thinking=False,
    )

    # Run training.

@@ -137,11 +146,16 @@ def main_grpo():
    global_step = 0
    for epoch in range(total_epochs):
        for step, data in enumerate(train_dataloader):

            with stats_tracker.record_timing("rollout"):
                batch = rollout.rollout(data, workflow=workflow)

            batch = batch.to(actor.device)

            if config.actor.recompute_logprob:
                with stats_tracker.record_timing("recompute_logp"):
                    batch["logprobs"] = actor.compute_logp(batch)

            if ref is not None:
                with stats_tracker.record_timing("ref_logp"):
                    batch["ref_logp"] = ref.compute_logp(batch)

@@ -156,6 +170,22 @@ def main_grpo():
                stats = actor.ppo_update(batch)
                actor.step_lr_scheduler()

            with stats_tracker.record_timing("update_weights"):
                meta = WeightUpdateMeta(
                    type="disk",
                    path=os.path.join(config.cluster.fileroot, "update_weights"),
                    alloc_mode=None,
                    comm_backend=None,
                    model_version=global_step + 1,
                )
                if dist.get_rank() == 0:
                    future = rollout.update_weights(meta)
                actor.upload_weights(meta)
                if dist.get_rank() == 0:
                    future.result()
                rollout.set_version(global_step)
                dist.barrier()

            with stats_tracker.record_timing("save"):
                saver.save(actor, epoch, step, global_step)

@@ -169,13 +199,13 @@ def main_grpo():
                    eval_rollout.submit(item, workflow)
                    cnt += 1
                batch = eval_rollout.wait(cnt, timeout=None)
                rewards = batch["rewards"]
                rewards = batch["rewards"].float()
                with stats_tracker.scope("grpo-eval"):
                    stats_tracker.denominator(
                        n_seqs=torch.ones(
                            rewards.shape[0],
                            device=rewards.device,
                            dtype=rewards.dtype,
                            dtype=torch.bool,
                        )
                    )
                    stats_tracker.stat(task_reward=rewards, denominator="n_seqs")

@@ -190,11 +220,16 @@ def main_grpo():

            logger.commit(epoch, step, global_step, stats)
            global_step += 1
            break
        break

    actor.destroy()
    if ref is not None:
        ref.destroy()
    rollout.destroy()
    eval_rollout.destroy()
    logger.close()
    dist.destroy_process_group()


if __name__ == "__main__":
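A hedged sketch of what the reworked gsm8k_reward_fn computes: pull a final answer out of the completion and out of the dataset's reference answer, then compare the stripped strings. The regexes below stand in for extract_answer/extract_solution and are illustrative assumptions, not the repository's parsers.

import re

# Simplified stand-in for gsm8k_reward_fn above; the '####' pattern follows the GSM8K
# answer format, and the numeric-extraction regex is illustrative only.
def toy_gsm8k_reward(completions: str, answer: str) -> int:
    pred = re.findall(r"-?\d[\d,\.]*", completions)         # last number in the completion
    ref = re.search(r"####\s*(-?[\d,\.]+)", answer)         # number after '####' in the label
    if not pred or ref is None:
        return 0
    return int(pred[-1].strip() == ref.group(1).strip())

print(toy_gsm8k_reward("... so the total is 42", "She sold 6 clips a day. #### 42"))  # 1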
@@ -1,6 +1,7 @@
import os
import sys

import torch.distributed as dist
from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node
from torchdata.stateful_dataloader import StatefulDataLoader

@@ -115,6 +116,7 @@ def main_sft():

    engine.destroy()
    logger.close()
    dist.destroy_process_group()


if __name__ == "__main__":
@@ -193,7 +193,7 @@ class DistributedStatsTracker:
        values = self.stats[key]
        if key not in self.denominators:
            x = sum([x.sum() for x in values])
            if reduce_group is not None:
            if dist.is_initialized():
                dist.all_reduce(x, group=reduce_group)
        else:
            denominator = self.denominators[key]

@@ -205,7 +205,7 @@ class DistributedStatsTracker:
            for v, d in zip(values, self.stats[denominator]):
                xs.append(torch.where(d, v, 0.0).sum())
            x = sum(xs)
            if reduce_group is not None:
            if dist.is_initialized():
                dist.all_reduce(x, group=reduce_group)
        return float(x)

@@ -221,7 +221,7 @@ class DistributedStatsTracker:
            ds.append(d.sum())
        x = sum(xs)
        d = sum(ds)
        if reduce_group is not None:
        if dist.is_initialized():
            dist.all_reduce(x, group=reduce_group)
            dist.all_reduce(d, group=reduce_group)
        if d == 0:

@@ -237,7 +237,7 @@ class DistributedStatsTracker:
        for v, d in zip(values, self.stats[denominator]):
            xs.append(torch.where(d, v, float("inf")).min())
        x = min(xs)
        if reduce_group is not None:
        if dist.is_initialized():
            dist.all_reduce(x, group=reduce_group, op=dist.ReduceOp.MIN)
        if torch.isinf(x):
            return None

@@ -252,7 +252,7 @@ class DistributedStatsTracker:
        for v, d in zip(values, self.stats[denominator]):
            xs.append(torch.where(d, v, -float("inf")).max())
        x = max(xs)
        if reduce_group is not None:
        if dist.is_initialized():
            dist.all_reduce(x, group=reduce_group, op=dist.ReduceOp.MAX)
        if torch.isinf(x):
            return None
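The reductions touched here all follow one pattern: mask the values with the registered denominator, sum, and divide (with an all_reduce in the distributed case). A minimal single-process sketch:

import torch

# Sketch of the masked-average pattern used by DistributedStatsTracker (single process,
# so the dist.all_reduce calls are omitted).
values = torch.tensor([1.0, 2.0, 3.0, 4.0])
mask = torch.tensor([True, False, True, True])   # the "denominator" entries

x = torch.where(mask, values, 0.0).sum()         # masked sum, as in the hunks above
d = mask.sum()
avg = float(x) / float(d) if d > 0 else None     # 8.0 / 3 ≈ 2.67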