mirror of https://github.com/inclusionAI/AReaL

Commit 932f9b9232 ("32k run"), parent eda0e79725
@@ -181,6 +181,7 @@ class TrainEngineConfig:
     mb_spec: MicroBatchSpec = field(default_factory=MicroBatchSpec)

     # Training Backend Configuration
+    disable_dropout: bool = False
     gradient_checkpointing: bool = field(
         default=True, metadata={"help": "Enable gradient checkpointing"}
     )
@@ -46,6 +46,7 @@ from arealite.utils.fsdp import (
     fsdp2_load_full_state_dict,
     get_cosine_schedule_with_warmup,
 )
+from arealite.utils.model import disable_dropout_in_model
 from arealite.utils.save_load import get_state_dict_from_repo_id_or_path
 from realhf.api.core.data_api import load_hf_tokenizer
 from realhf.base import logging, name_resolve, names, pkg_version
@@ -112,14 +113,14 @@ class FSDPEngine(TrainEngine):
                 attn_implementation=self.config.attn_impl,
             )
         else:
-            from liger_kernel.transformers import AutoLigerKernelForCausalLM
-
-            model = AutoLigerKernelForCausalLM.from_pretrained(
+            model = AutoModelForCausalLM.from_pretrained(
                 pretrained_model_name_or_path=self.config.path,
                 trust_remote_code=True,
                 torch_dtype=dtype,
                 attn_implementation=self.config.attn_impl,
             )
+        if self.config.disable_dropout:
+            disable_dropout_in_model(model)
         if self.config.gradient_checkpointing:
             model.gradient_checkpointing_enable(
                 gradient_checkpointing_kwargs={"use_reentrant": False}
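Dropout is zeroed before the model is handed to FSDP, and gradient checkpointing is enabled in its non-reentrant form (the variant recent PyTorch recommends for composing with newer wrappers). A minimal stand-alone sketch of the same load-time pattern, assuming a `cfg` object with the fields used above; this is an illustration, not the engine's actual API surface:

    import torch
    from transformers import AutoModelForCausalLM

    def load_for_training(cfg):
        model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path=cfg.path,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            attn_implementation=cfg.attn_impl,
        )
        if cfg.disable_dropout:
            # zero every dropout layer so train-mode forwards are deterministic
            for m in model.modules():
                if isinstance(m, torch.nn.Dropout):
                    m.p = 0
        if cfg.gradient_checkpointing:
            # non-reentrant checkpointing, matching the hunk above
            model.gradient_checkpointing_enable(
                gradient_checkpointing_kwargs={"use_reentrant": False}
            )
        return model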
@@ -335,6 +336,9 @@ class FSDPEngine(TrainEngine):
         input_ = TensorDict(input_, batch_size=[input_["input_ids"].shape[0]])
         input_ = amend_position_ids(input_)
         mb_list = split_padded_tensor_dict_into_mb_list(input_, self.config.mb_spec)
+        logger.info(
+            f"Microbatch #tokens (rank {dist.get_rank()}): {mb_list.group_lens}"
+        )
         mb_list.mbs = [pack_tensor_dict(mb) for mb in mb_list.mbs]
         mb_list = pad_mb_list(mb_list, pad_value=0.0)
         # NOTE: We unsqueeze here because huggingface transformer models requires
@@ -364,10 +368,11 @@ class FSDPEngine(TrainEngine):
         dist.all_reduce(total_loss_weight)

         # Process microbatches with gradient accumulation
-        for pad_length, padded_mb_input, mb_input in zip(
-            mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
+        for i, (pad_length, padded_mb_input, mb_input) in enumerate(
+            zip(mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs)
         ):
-            outputs = self.model(**padded_mb_input)
+            self.model.set_is_last_backward(i == len(mb_list.mbs) - 1)
+            outputs = self.model(**padded_mb_input, use_cache=False)

             logits = outputs.logits.squeeze(0)
             logits = logits[:-pad_length] if pad_length > 0 else logits
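Two changes in the gradient-accumulation loop: the forward pass now sets use_cache=False, so no KV cache is allocated during training forwards (the same flag is added to the eval and forward paths in the next two hunks), and the loop tells the FSDP2-wrapped model which microbatch is last, presumably so gradient reduction is finalized only on the final backward of the accumulation window. A rough sketch of that pattern with illustrative names (microbatches, loss_fn, optimizer are placeholders, not from the diff):

    for i, mb in enumerate(microbatches):
        # only the last microbatch should trigger FSDP2's post-backward
        # gradient reduction; earlier ones just accumulate local gradients
        model.set_is_last_backward(i == len(microbatches) - 1)
        out = model(**mb, use_cache=False)
        loss_fn(out.logits.squeeze(0), mb).backward()
    optimizer.step()
    optimizer.zero_grad()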
@@ -421,7 +426,7 @@ class FSDPEngine(TrainEngine):
         for pad_length, padded_mb_input, mb_input in zip(
             mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
         ):
-            outputs = self.model(**padded_mb_input)
+            outputs = self.model(**padded_mb_input, use_cache=False)
             logits = outputs.logits.squeeze(0)
             logits = logits[:-pad_length] if pad_length > 0 else logits
             loss = loss_fn(logits, mb_input)
@@ -453,7 +458,7 @@ class FSDPEngine(TrainEngine):
         for pad_length, padded_mb_input, mb_input in zip(
             mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
         ):
-            outputs = self.model(**padded_mb_input)
+            outputs = self.model(**padded_mb_input, use_cache=False)
             logits = outputs.logits.squeeze(0)
             logits = logits[:-pad_length] if pad_length > 0 else logits

@@ -211,14 +211,13 @@ class PPOActor:
        for key in ["rewards", "tot_rewards", "kl_rewards", "versions"]:
            data.pop(key, None)
        # NOTE: calling engine.train() is critical to enabling gradient checkpointing
        self.engine.train()
        mb_inputs = split_padded_tensor_dict_into_mb_list(
            data,
            mb_spec=MicroBatchSpec(n_mbs=self.config.ppo_n_minibatches),
        )
        for mb in mb_inputs.mbs:
            gc.collect()
            torch.cuda.empty_cache()
            gc.collect()
            train_stat = self.engine.train_batch(
                mb,
                loss_fn=functools.partial(
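The gc.collect() / torch.cuda.empty_cache() / gc.collect() sequence at the top of each minibatch first drops Python references left over from the previous minibatch so their CUDA blocks become free, then returns those cached blocks to the driver, then runs a final collection pass; presumably this is here to limit caching-allocator fragmentation when every minibatch carries sequences tens of thousands of tokens long, at the cost of some extra step time.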
@@ -312,7 +312,7 @@ def split_padded_tensor_dict_into_mb_list(
     )
     bs = data["attention_mask"].shape[0]
     max_seqlen = data["attention_mask"].shape[1]
-    input_lens = data["attention_mask"].sum(1).cpu().numpy()
+    input_lens = data["attention_mask"].sum(1).long().cpu().numpy()

     # check tensor shape, split only 1d tensors with length "total_lens"
     to_split = {}
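The added .long() guarantees the summed sequence lengths are int64 even when the attention mask is stored in a boolean or floating-point dtype, so the following .cpu().numpy() always yields an integer array for microbatch grouping.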
@@ -0,0 +1,35 @@
+import os
+import socket
+from functools import lru_cache
+from typing import Tuple
+
+import numpy as np
+import tabulate
+import torch
+import torch.distributed as dist
+
+from realhf.base import logging
+
+logger = logging.getLogger(__file__)
+
+
+def _get_current_mem_info(unit: str = "GB", precision: int = 2) -> Tuple[str]:
+    """Get current memory usage."""
+    assert unit in ["GB", "MB", "KB"]
+    divisor = 1024**3 if unit == "GB" else 1024**2 if unit == "MB" else 1024
+    mem_allocated = torch.cuda.memory_allocated()
+    mem_reserved = torch.cuda.memory_reserved()
+    mem_free, mem_total = torch.cuda.mem_get_info()
+    mem_used = mem_total - mem_free
+    mem_allocated = f"{mem_allocated / divisor:.{precision}f}"
+    mem_reserved = f"{mem_reserved / divisor:.{precision}f}"
+    mem_used = f"{mem_used / divisor:.{precision}f}"
+    mem_total = f"{mem_total / divisor:.{precision}f}"
+    return mem_allocated, mem_reserved, mem_used, mem_total
+
+
+def log_gpu_stats(head: str, rank: int = 0):
+    if (not dist.is_initialized()) or (rank is None) or (dist.get_rank() == rank):
+        mem_allocated, mem_reserved, mem_used, mem_total = _get_current_mem_info()
+        message = f"{head}, memory allocated (GB): {mem_allocated}, memory reserved (GB): {mem_reserved}, device memory used/total (GB): {mem_used}/{mem_total}"
+        logger.info(msg=message)
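Usage sketch for the new helper (the import path matches the one added to the GRPO example later in this commit); a call logs only on the chosen rank, on every rank when rank is None, and unconditionally when torch.distributed is not initialized:

    from arealite.utils.device import log_gpu_stats

    log_gpu_stats("after rollout")                 # default: rank 0 only
    log_gpu_stats("after PPO update", rank=None)   # every rank reports its own memory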
@@ -1,4 +1,8 @@
+from typing import Dict, Optional, Tuple
+
+import numpy as np
 import torch
+import torch.distributed as dist


 @torch.compile
@@ -11,7 +15,7 @@ def gather_logprobs(


 @torch.compile
-def gather_logprobs_entropy(
+def _gather_logprobs_entropy(
     logits: torch.Tensor, labels: torch.Tensor, temperature: float = 1.0
 ):
     log_probs = torch.nn.functional.log_softmax(logits.float() / temperature, dim=-1)
@@ -20,11 +24,33 @@ def gather_logprobs_entropy(
     return log_probs_labels, entropy


-from typing import Dict, Optional, Tuple
-
-import numpy as np
-import torch
-import torch.distributed as dist
+def gather_logprobs_entropy(
+    logits: torch.Tensor,
+    labels: torch.Tensor,
+    temperature: float = 1.0,
+    chunk_size: int = 1024,
+):
+    batch_size = logits.shape[0]
+
+    if batch_size <= chunk_size:
+        return _gather_logprobs_entropy(logits, labels, temperature)
+
+    log_probs_labels_list = []
+    entropy_list = []
+
+    for i in range(0, batch_size, chunk_size):
+        end_idx = min(i + chunk_size, batch_size)
+        chunk_logits = logits[i:end_idx]
+        chunk_labels = labels[i:end_idx]
+
+        chunk_log_probs, chunk_entropy = _gather_logprobs_entropy(
+            chunk_logits, chunk_labels, temperature
+        )
+
+        log_probs_labels_list.append(chunk_log_probs)
+        entropy_list.append(chunk_entropy)
+
+    return torch.cat(log_probs_labels_list), torch.cat(entropy_list)


 @torch.no_grad()
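The public gather_logprobs_entropy is now an uncompiled wrapper: small inputs go straight to the @torch.compile'd _gather_logprobs_entropy, while longer packed sequences are processed in chunk_size slices along the token dimension so the float32 upcast and log_softmax temporaries stay bounded per chunk. A usage sketch with packed 1D inputs (the shapes are illustrative, not taken from the diff):

    import torch

    logits = torch.randn(4096, 32000)            # [total_tokens, vocab_size]
    labels = torch.randint(0, 32000, (4096,))    # per-token target ids
    logp, entropy = gather_logprobs_entropy(
        logits, labels, temperature=1.0, chunk_size=1024
    )
    assert logp.shape == entropy.shape == (4096,)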
@@ -0,0 +1,8 @@
+import torch
+
+
+# Copied from trl
+def disable_dropout_in_model(model: torch.nn.Module) -> None:
+    for module in model.modules():
+        if isinstance(module, torch.nn.Dropout):
+            module.p = 0
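disable_dropout_in_model (taken from trl, per the comment) mutates the model in place, setting p = 0 on every torch.nn.Dropout it finds; for PPO-style updates this keeps train-mode forward passes deterministic, so recomputed log-probabilities are not perturbed by dropout noise. Usage, with the import path this commit adds to the FSDP engine:

    from arealite.utils.model import disable_dropout_in_model

    disable_dropout_in_model(model)  # in-place; returns None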
@@ -28,6 +28,7 @@ from arealite.api.workflow_api import RolloutWorkflow
 from arealite.engine.ppo.actor import FSDPPPOActor
 from arealite.engine.sglang_remote import RemoteSGLangEngine
 from arealite.utils.data import concat_padded_tensors
+from arealite.utils.device import log_gpu_stats
 from arealite.utils.evaluator import Evaluator
 from arealite.utils.saver import Saver
 from arealite.utils.stats_logger import StatsLogger
@@ -187,25 +188,23 @@ def main_grpo():
                 batch["logprobs"] = logp
             else:
                 batch["prox_logp"] = logp
+            log_gpu_stats("Recompute logp")

         if ref is not None:
             with stats_tracker.record_timing("ref_logp"):
                 batch["ref_logp"] = ref.compute_logp(batch)
+                log_gpu_stats("Ref logp")

         with stats_tracker.record_timing("compute_advantage"):
             actor.compute_advantages(batch)

         with stats_tracker.record_timing("clear_cache"):
             gc.collect()
             torch.cuda.empty_cache()
             gc.collect()

         with (
             stats_tracker.record_timing("train_step"),
             stats_tracker.scope("grpo_actor"),
         ):
             stats = actor.ppo_update(batch)
             actor.step_lr_scheduler()
+            log_gpu_stats("PPO update")

         with stats_tracker.record_timing("update_weights"):
             meta = WeightUpdateMeta(
@@ -10,7 +10,6 @@ cluster:
   name_resolve:
     type: etcd3
     etcd3_addr: etcd-client.openpsi-etcd.svc.sigma-na130-lingbo.na130.wl-robby.local:2379
-  exclude: slurmd-69
   gpu_image: /storage/openpsi/images/areal-v0.3.0.post1.sif
   gpu_infer_image: /storage/openpsi/images/areal-v0.3.0.post1.sif
 seed: 1
@@ -30,7 +29,7 @@ rollout:
   gconfig:
     n_samples: 16
     min_new_tokens: 0
-    max_new_tokens: 15360
+    max_new_tokens: 30720
     greedy: false
     temperature: 1.0

@@ -39,10 +38,11 @@ actor:
   trial_name: ${trial_name}
   path: /storage/openpsi/models/deepseek-ai__DeepSeek-R1-Distill-Qwen-1.5B/
   init_from_scratch: false
+  disable_dropout: true
   gradient_checkpointing: true
   dtype: bfloat16
   mb_spec:
-    max_tokens_per_mb: 16384
+    max_tokens_per_mb: 32768
   optimizer:
     type: adam
     lr: 1e-6
@@ -63,8 +63,8 @@ actor:
   reward_bias: -0.5
   kl_ctl: 0.0
   ppo_n_minibatches: 1
-  recompute_logprob: true
-  use_decoupled_loss: true
+  recompute_logprob: false
+  use_decoupled_loss: false
   behav_imp_weight_cap: 5.0

 ref:
@@ -72,9 +72,10 @@ ref:
   trial_name: ${trial_name}
   path: ${actor.path}
   init_from_scratch: false
+  disable_dropout: true
   dtype: ${actor.dtype}
   mb_spec:
-    max_tokens_per_mb: 16384
+    max_tokens_per_mb: 32768
   optimizer: null
   backend: fsdp

@@ -91,7 +92,7 @@ sglang:

 # datasets
 train_dataset:
-  batch_size: 128
+  batch_size: 16
   shuffle: true
   pin_memory: true
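Taken together, the config changes are what the commit title calls a 32k run: responses may now be up to 30720 new tokens (double the previous 15360), and both the actor and ref engines raise max_tokens_per_mb to 32768, so a single full-length response plus its prompt roughly fills one microbatch. Cutting train_dataset.batch_size from 128 to 16 and turning off recompute_logprob / use_decoupled_loss presumably keep the per-step memory and compute budget manageable at this length; assuming batch_size counts prompts, 16 prompts with n_samples: 16 still yield 256 sampled responses per step.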