This commit is contained in:
bowei.fw 2025-07-13 21:48:48 +08:00
parent eda0e79725
commit 932f9b9232
9 changed files with 103 additions and 29 deletions

View File

@ -181,6 +181,7 @@ class TrainEngineConfig:
mb_spec: MicroBatchSpec = field(default_factory=MicroBatchSpec)
# Training Backend Configuration
disable_dropout: bool = False
gradient_checkpointing: bool = field(
default=True, metadata={"help": "Enable gradient checkpointing"}
)

View File

@ -46,6 +46,7 @@ from arealite.utils.fsdp import (
fsdp2_load_full_state_dict,
get_cosine_schedule_with_warmup,
)
from arealite.utils.model import disable_dropout_in_model
from arealite.utils.save_load import get_state_dict_from_repo_id_or_path
from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import logging, name_resolve, names, pkg_version
@ -112,14 +113,14 @@ class FSDPEngine(TrainEngine):
attn_implementation=self.config.attn_impl,
)
else:
from liger_kernel.transformers import AutoLigerKernelForCausalLM
model = AutoLigerKernelForCausalLM.from_pretrained(
model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=self.config.path,
trust_remote_code=True,
torch_dtype=dtype,
attn_implementation=self.config.attn_impl,
)
if self.config.disable_dropout:
disable_dropout_in_model(model)
if self.config.gradient_checkpointing:
model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False}
@ -335,6 +336,9 @@ class FSDPEngine(TrainEngine):
input_ = TensorDict(input_, batch_size=[input_["input_ids"].shape[0]])
input_ = amend_position_ids(input_)
mb_list = split_padded_tensor_dict_into_mb_list(input_, self.config.mb_spec)
logger.info(
f"Microbatch #tokens (rank {dist.get_rank()}): {mb_list.group_lens}"
)
mb_list.mbs = [pack_tensor_dict(mb) for mb in mb_list.mbs]
mb_list = pad_mb_list(mb_list, pad_value=0.0)
# NOTE: We unsqueeze here because huggingface transformer models requires
@ -364,10 +368,11 @@ class FSDPEngine(TrainEngine):
dist.all_reduce(total_loss_weight)
# Process microbatches with gradient accumulation
for pad_length, padded_mb_input, mb_input in zip(
mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
for i, (pad_length, padded_mb_input, mb_input) in enumerate(
zip(mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs)
):
outputs = self.model(**padded_mb_input)
self.model.set_is_last_backward(i == len(mb_list.mbs) - 1)
outputs = self.model(**padded_mb_input, use_cache=False)
logits = outputs.logits.squeeze(0)
logits = logits[:-pad_length] if pad_length > 0 else logits
@ -421,7 +426,7 @@ class FSDPEngine(TrainEngine):
for pad_length, padded_mb_input, mb_input in zip(
mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
):
outputs = self.model(**padded_mb_input)
outputs = self.model(**padded_mb_input, use_cache=False)
logits = outputs.logits.squeeze(0)
logits = logits[:-pad_length] if pad_length > 0 else logits
loss = loss_fn(logits, mb_input)
@ -453,7 +458,7 @@ class FSDPEngine(TrainEngine):
for pad_length, padded_mb_input, mb_input in zip(
mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
):
outputs = self.model(**padded_mb_input)
outputs = self.model(**padded_mb_input, use_cache=False)
logits = outputs.logits.squeeze(0)
logits = logits[:-pad_length] if pad_length > 0 else logits

View File

@ -211,14 +211,13 @@ class PPOActor:
for key in ["rewards", "tot_rewards", "kl_rewards", "versions"]:
data.pop(key, None)
# NOTE: calling engine.train() is critical to enabling gradient checkpointing
self.engine.train()
mb_inputs = split_padded_tensor_dict_into_mb_list(
data,
mb_spec=MicroBatchSpec(n_mbs=self.config.ppo_n_minibatches),
)
for mb in mb_inputs.mbs:
gc.collect()
torch.cuda.empty_cache()
gc.collect()
train_stat = self.engine.train_batch(
mb,
loss_fn=functools.partial(

View File

@ -312,7 +312,7 @@ def split_padded_tensor_dict_into_mb_list(
)
bs = data["attention_mask"].shape[0]
max_seqlen = data["attention_mask"].shape[1]
input_lens = data["attention_mask"].sum(1).cpu().numpy()
input_lens = data["attention_mask"].sum(1).long().cpu().numpy()
# check tensor shape, split only 1d tensors with length "total_lens"
to_split = {}

35
arealite/utils/device.py Normal file
View File

@ -0,0 +1,35 @@
import os
import socket
from functools import lru_cache
from typing import Tuple
import numpy as np
import tabulate
import torch
import torch.distributed as dist
from realhf.base import logging
logger = logging.getLogger(__file__)
def _get_current_mem_info(unit: str = "GB", precision: int = 2) -> Tuple[str]:
"""Get current memory usage."""
assert unit in ["GB", "MB", "KB"]
divisor = 1024**3 if unit == "GB" else 1024**2 if unit == "MB" else 1024
mem_allocated = torch.cuda.memory_allocated()
mem_reserved = torch.cuda.memory_reserved()
mem_free, mem_total = torch.cuda.mem_get_info()
mem_used = mem_total - mem_free
mem_allocated = f"{mem_allocated / divisor:.{precision}f}"
mem_reserved = f"{mem_reserved / divisor:.{precision}f}"
mem_used = f"{mem_used / divisor:.{precision}f}"
mem_total = f"{mem_total / divisor:.{precision}f}"
return mem_allocated, mem_reserved, mem_used, mem_total
def log_gpu_stats(head: str, rank: int = 0):
if (not dist.is_initialized()) or (rank is None) or (dist.get_rank() == rank):
mem_allocated, mem_reserved, mem_used, mem_total = _get_current_mem_info()
message = f"{head}, memory allocated (GB): {mem_allocated}, memory reserved (GB): {mem_reserved}, device memory used/total (GB): {mem_used}/{mem_total}"
logger.info(msg=message)

View File

@ -1,4 +1,8 @@
from typing import Dict, Optional, Tuple
import numpy as np
import torch
import torch.distributed as dist
@torch.compile
@ -11,7 +15,7 @@ def gather_logprobs(
@torch.compile
def gather_logprobs_entropy(
def _gather_logprobs_entropy(
logits: torch.Tensor, labels: torch.Tensor, temperature: float = 1.0
):
log_probs = torch.nn.functional.log_softmax(logits.float() / temperature, dim=-1)
@ -20,11 +24,33 @@ def gather_logprobs_entropy(
return log_probs_labels, entropy
from typing import Dict, Optional, Tuple
def gather_logprobs_entropy(
logits: torch.Tensor,
labels: torch.Tensor,
temperature: float = 1.0,
chunk_size: int = 1024,
):
batch_size = logits.shape[0]
import numpy as np
import torch
import torch.distributed as dist
if batch_size <= chunk_size:
return _gather_logprobs_entropy(logits, labels, temperature)
log_probs_labels_list = []
entropy_list = []
for i in range(0, batch_size, chunk_size):
end_idx = min(i + chunk_size, batch_size)
chunk_logits = logits[i:end_idx]
chunk_labels = labels[i:end_idx]
chunk_log_probs, chunk_entropy = _gather_logprobs_entropy(
chunk_logits, chunk_labels, temperature
)
log_probs_labels_list.append(chunk_log_probs)
entropy_list.append(chunk_entropy)
return torch.cat(log_probs_labels_list), torch.cat(entropy_list)
@torch.no_grad()

8
arealite/utils/model.py Normal file
View File

@ -0,0 +1,8 @@
import torch
# Copied from trl
def disable_dropout_in_model(model: torch.nn.Module) -> None:
for module in model.modules():
if isinstance(module, torch.nn.Dropout):
module.p = 0

View File

@ -28,6 +28,7 @@ from arealite.api.workflow_api import RolloutWorkflow
from arealite.engine.ppo.actor import FSDPPPOActor
from arealite.engine.sglang_remote import RemoteSGLangEngine
from arealite.utils.data import concat_padded_tensors
from arealite.utils.device import log_gpu_stats
from arealite.utils.evaluator import Evaluator
from arealite.utils.saver import Saver
from arealite.utils.stats_logger import StatsLogger
@ -187,25 +188,23 @@ def main_grpo():
batch["logprobs"] = logp
else:
batch["prox_logp"] = logp
log_gpu_stats("Recompute logp")
if ref is not None:
with stats_tracker.record_timing("ref_logp"):
batch["ref_logp"] = ref.compute_logp(batch)
log_gpu_stats("Ref logp")
with stats_tracker.record_timing("compute_advantage"):
actor.compute_advantages(batch)
with stats_tracker.record_timing("clear_cache"):
gc.collect()
torch.cuda.empty_cache()
gc.collect()
with (
stats_tracker.record_timing("train_step"),
stats_tracker.scope("grpo_actor"),
):
stats = actor.ppo_update(batch)
actor.step_lr_scheduler()
log_gpu_stats("PPO update")
with stats_tracker.record_timing("update_weights"):
meta = WeightUpdateMeta(

View File

@ -10,7 +10,6 @@ cluster:
name_resolve:
type: etcd3
etcd3_addr: etcd-client.openpsi-etcd.svc.sigma-na130-lingbo.na130.wl-robby.local:2379
exclude: slurmd-69
gpu_image: /storage/openpsi/images/areal-v0.3.0.post1.sif
gpu_infer_image: /storage/openpsi/images/areal-v0.3.0.post1.sif
seed: 1
@ -30,7 +29,7 @@ rollout:
gconfig:
n_samples: 16
min_new_tokens: 0
max_new_tokens: 15360
max_new_tokens: 30720
greedy: false
temperature: 1.0
@ -39,10 +38,11 @@ actor:
trial_name: ${trial_name}
path: /storage/openpsi/models/deepseek-ai__DeepSeek-R1-Distill-Qwen-1.5B/
init_from_scratch: false
disable_dropout: true
gradient_checkpointing: true
dtype: bfloat16
mb_spec:
max_tokens_per_mb: 16384
max_tokens_per_mb: 32768
optimizer:
type: adam
lr: 1e-6
@ -63,8 +63,8 @@ actor:
reward_bias: -0.5
kl_ctl: 0.0
ppo_n_minibatches: 1
recompute_logprob: true
use_decoupled_loss: true
recompute_logprob: false
use_decoupled_loss: false
behav_imp_weight_cap: 5.0
ref:
@ -72,9 +72,10 @@ ref:
trial_name: ${trial_name}
path: ${actor.path}
init_from_scratch: false
disable_dropout: true
dtype: ${actor.dtype}
mb_spec:
max_tokens_per_mb: 16384
max_tokens_per_mb: 32768
optimizer: null
backend: fsdp
@ -91,7 +92,7 @@ sglang:
# datasets
train_dataset:
batch_size: 128
batch_size: 16
shuffle: true
pin_memory: true