mirror of https://github.com/inclusionAI/AReaL
32k run
This commit is contained in: parent eda0e79725, commit 932f9b9232
@@ -181,6 +181,7 @@ class TrainEngineConfig:
     mb_spec: MicroBatchSpec = field(default_factory=MicroBatchSpec)

     # Training Backend Configuration
+    disable_dropout: bool = False
     gradient_checkpointing: bool = field(
         default=True, metadata={"help": "Enable gradient checkpointing"}
     )
@@ -46,6 +46,7 @@ from arealite.utils.fsdp import (
     fsdp2_load_full_state_dict,
     get_cosine_schedule_with_warmup,
 )
+from arealite.utils.model import disable_dropout_in_model
 from arealite.utils.save_load import get_state_dict_from_repo_id_or_path
 from realhf.api.core.data_api import load_hf_tokenizer
 from realhf.base import logging, name_resolve, names, pkg_version
@@ -112,14 +113,14 @@ class FSDPEngine(TrainEngine):
                 attn_implementation=self.config.attn_impl,
             )
         else:
-            from liger_kernel.transformers import AutoLigerKernelForCausalLM
-
-            model = AutoLigerKernelForCausalLM.from_pretrained(
+            model = AutoModelForCausalLM.from_pretrained(
                 pretrained_model_name_or_path=self.config.path,
                 trust_remote_code=True,
                 torch_dtype=dtype,
                 attn_implementation=self.config.attn_impl,
             )
+        if self.config.disable_dropout:
+            disable_dropout_in_model(model)
         if self.config.gradient_checkpointing:
             model.gradient_checkpointing_enable(
                 gradient_checkpointing_kwargs={"use_reentrant": False}
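The else branch above swaps the Liger-kernel model class for the stock AutoModelForCausalLM and wires the new disable_dropout flag through the helper imported earlier in this commit. A minimal sketch of the same model preparation outside the engine (the checkpoint name and attention backend are illustrative, not taken from this repo):

import torch
from transformers import AutoModelForCausalLM
from arealite.utils.model import disable_dropout_in_model

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B",          # placeholder checkpoint
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",   # stands in for self.config.attn_impl
)
disable_dropout_in_model(model)   # zero out every nn.Dropout (see the new utils module below)
model.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs={"use_reentrant": False}  # non-reentrant, as in the diff
)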
@@ -335,6 +336,9 @@ class FSDPEngine(TrainEngine):
         input_ = TensorDict(input_, batch_size=[input_["input_ids"].shape[0]])
         input_ = amend_position_ids(input_)
         mb_list = split_padded_tensor_dict_into_mb_list(input_, self.config.mb_spec)
+        logger.info(
+            f"Microbatch #tokens (rank {dist.get_rank()}): {mb_list.group_lens}"
+        )
         mb_list.mbs = [pack_tensor_dict(mb) for mb in mb_list.mbs]
         mb_list = pad_mb_list(mb_list, pad_value=0.0)
         # NOTE: We unsqueeze here because huggingface transformer models requires
@@ -364,10 +368,11 @@ class FSDPEngine(TrainEngine):
         dist.all_reduce(total_loss_weight)

         # Process microbatches with gradient accumulation
-        for pad_length, padded_mb_input, mb_input in zip(
-            mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
+        for i, (pad_length, padded_mb_input, mb_input) in enumerate(
+            zip(mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs)
         ):
-            outputs = self.model(**padded_mb_input)
+            self.model.set_is_last_backward(i == len(mb_list.mbs) - 1)
+            outputs = self.model(**padded_mb_input, use_cache=False)

             logits = outputs.logits.squeeze(0)
             logits = logits[:-pad_length] if pad_length > 0 else logits
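Two things change in the accumulation loop: the microbatches are enumerated so the engine can mark the final backward of the window for the FSDP2-wrapped module (set_is_last_backward), and use_cache=False keeps the HuggingFace forward pass from allocating KV caches that training never reads. A schematic of the pattern, assuming a model wrapped with torch's FSDP2 / fully_shard (which exposes set_is_last_backward):

def train_on_microbatches(model, microbatches, loss_fn):
    # gradients accumulate across microbatches; only the last backward is flagged
    for i, mb in enumerate(microbatches):
        model.set_is_last_backward(i == len(microbatches) - 1)
        outputs = model(**mb, use_cache=False)   # skip KV-cache allocation during training
        loss = loss_fn(outputs.logits, mb)
        loss.backward()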
@@ -421,7 +426,7 @@ class FSDPEngine(TrainEngine):
         for pad_length, padded_mb_input, mb_input in zip(
             mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
         ):
-            outputs = self.model(**padded_mb_input)
+            outputs = self.model(**padded_mb_input, use_cache=False)
             logits = outputs.logits.squeeze(0)
             logits = logits[:-pad_length] if pad_length > 0 else logits
             loss = loss_fn(logits, mb_input)
@@ -453,7 +458,7 @@ class FSDPEngine(TrainEngine):
         for pad_length, padded_mb_input, mb_input in zip(
             mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
         ):
-            outputs = self.model(**padded_mb_input)
+            outputs = self.model(**padded_mb_input, use_cache=False)
             logits = outputs.logits.squeeze(0)
             logits = logits[:-pad_length] if pad_length > 0 else logits

@@ -211,14 +211,13 @@ class PPOActor:

         for key in ["rewards", "tot_rewards", "kl_rewards", "versions"]:
             data.pop(key, None)
+        # NOTE: calling engine.train() is critical to enabling gradient checkpointing
+        self.engine.train()
         mb_inputs = split_padded_tensor_dict_into_mb_list(
             data,
             mb_spec=MicroBatchSpec(n_mbs=self.config.ppo_n_minibatches),
         )
         for mb in mb_inputs.mbs:
-            gc.collect()
-            torch.cuda.empty_cache()
-            gc.collect()
             train_stat = self.engine.train_batch(
                 mb,
                 loss_fn=functools.partial(
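The added comment is load-bearing: HuggingFace models only take the checkpointed code path while the module is in training mode (the blocks roughly check `if self.gradient_checkpointing and self.training:`), so if the engine were left in eval mode, for example after a log-prob recomputation pass, checkpointing would silently stop applying and full activations would be kept for 32k-token sequences. The per-minibatch gc.collect()/empty_cache() calls are dropped here; the matching cache-clearing block in the GRPO script is removed later in this commit as well. A schematic of the state the engine.train() call restores:

# sketch: gradient checkpointing follows module.training
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
model.eval()    # e.g. during inference-style passes: checkpointing is inactive here
model.train()   # must be restored before the PPO update for checkpointing to take effect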
@@ -312,7 +312,7 @@ def split_padded_tensor_dict_into_mb_list(
     )
     bs = data["attention_mask"].shape[0]
     max_seqlen = data["attention_mask"].shape[1]
-    input_lens = data["attention_mask"].sum(1).cpu().numpy()
+    input_lens = data["attention_mask"].sum(1).long().cpu().numpy()

     # check tensor shape, split only 1d tensors with length "total_lens"
     to_split = {}
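The added .long() makes the per-sequence lengths integer-typed regardless of how the attention mask is stored: a boolean mask already sums to int64, but a floating-point mask would yield float lengths, and a bfloat16 mask cannot be converted to NumPy at all. A small illustration:

import torch

mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])      # float-typed attention mask
print(mask.sum(1).cpu().numpy())                  # [3.]  float lengths
print(mask.sum(1).long().cpu().numpy())           # [3]   integer lengths, safe to use as sizes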
@@ -0,0 +1,35 @@
+import os
+import socket
+from functools import lru_cache
+from typing import Tuple
+
+import numpy as np
+import tabulate
+import torch
+import torch.distributed as dist
+
+from realhf.base import logging
+
+logger = logging.getLogger(__file__)
+
+
+def _get_current_mem_info(unit: str = "GB", precision: int = 2) -> Tuple[str]:
+    """Get current memory usage."""
+    assert unit in ["GB", "MB", "KB"]
+    divisor = 1024**3 if unit == "GB" else 1024**2 if unit == "MB" else 1024
+    mem_allocated = torch.cuda.memory_allocated()
+    mem_reserved = torch.cuda.memory_reserved()
+    mem_free, mem_total = torch.cuda.mem_get_info()
+    mem_used = mem_total - mem_free
+    mem_allocated = f"{mem_allocated / divisor:.{precision}f}"
+    mem_reserved = f"{mem_reserved / divisor:.{precision}f}"
+    mem_used = f"{mem_used / divisor:.{precision}f}"
+    mem_total = f"{mem_total / divisor:.{precision}f}"
+    return mem_allocated, mem_reserved, mem_used, mem_total
+
+
+def log_gpu_stats(head: str, rank: int = 0):
+    if (not dist.is_initialized()) or (rank is None) or (dist.get_rank() == rank):
+        mem_allocated, mem_reserved, mem_used, mem_total = _get_current_mem_info()
+        message = f"{head}, memory allocated (GB): {mem_allocated}, memory reserved (GB): {mem_reserved}, device memory used/total (GB): {mem_used}/{mem_total}"
+        logger.info(msg=message)
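This is the new GPU-memory logging utility (imported later in this commit as `from arealite.utils.device import log_gpu_stats`). The guard in log_gpu_stats means: log unconditionally when torch.distributed is not initialized, log on a single chosen rank by default, and log on every rank when rank=None is passed. A usage sketch, with illustrative phase labels:

log_gpu_stats("Before PPO update")             # default: rank 0 only
stats = actor.ppo_update(batch)                # any memory-heavy phase
log_gpu_stats("After PPO update", rank=None)   # rank=None: every rank reports its own numbers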
@@ -1,4 +1,8 @@
+from typing import Dict, Optional, Tuple
+
+import numpy as np
 import torch
+import torch.distributed as dist


 @torch.compile
@@ -11,7 +15,7 @@ def gather_logprobs(


 @torch.compile
-def gather_logprobs_entropy(
+def _gather_logprobs_entropy(
     logits: torch.Tensor, labels: torch.Tensor, temperature: float = 1.0
 ):
     log_probs = torch.nn.functional.log_softmax(logits.float() / temperature, dim=-1)
@@ -20,11 +24,33 @@ def gather_logprobs_entropy(
     return log_probs_labels, entropy


-from typing import Dict, Optional, Tuple
-
-import numpy as np
-import torch
-import torch.distributed as dist
+def gather_logprobs_entropy(
+    logits: torch.Tensor,
+    labels: torch.Tensor,
+    temperature: float = 1.0,
+    chunk_size: int = 1024,
+):
+    batch_size = logits.shape[0]
+
+    if batch_size <= chunk_size:
+        return _gather_logprobs_entropy(logits, labels, temperature)
+
+    log_probs_labels_list = []
+    entropy_list = []
+
+    for i in range(0, batch_size, chunk_size):
+        end_idx = min(i + chunk_size, batch_size)
+        chunk_logits = logits[i:end_idx]
+        chunk_labels = labels[i:end_idx]
+
+        chunk_log_probs, chunk_entropy = _gather_logprobs_entropy(
+            chunk_logits, chunk_labels, temperature
+        )
+
+        log_probs_labels_list.append(chunk_log_probs)
+        entropy_list.append(chunk_entropy)
+
+    return torch.cat(log_probs_labels_list), torch.cat(entropy_list)


 @torch.no_grad()
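The old gather_logprobs_entropy becomes a private compiled kernel and gains a chunked front end; the duplicate imports that used to sit mid-file are removed here and moved to the top of the module in the earlier hunk. The motivation is peak memory: the kernel materializes a float32 log-softmax of shape [tokens, vocab], so for a 32k-token packed microbatch and a vocabulary of roughly 150k entries (as with the Qwen-family tokenizer used in this run) that is about 32768 * 151936 * 4 bytes, around 18.5 GiB at once, whereas 1024-token slices need about 0.6 GiB each. Usage is unchanged apart from the optional chunk size:

# logits: [total_tokens, vocab] after squeeze, labels: [total_tokens]
logp, entropy = gather_logprobs_entropy(logits, labels, temperature=1.0, chunk_size=1024)
# chunk results are concatenated along dim 0, matching the unchunked kernel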
@@ -0,0 +1,8 @@
+import torch
+
+
+# Copied from trl
+def disable_dropout_in_model(model: torch.nn.Module) -> None:
+    for module in model.modules():
+        if isinstance(module, torch.nn.Dropout):
+            module.p = 0
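The new helper (noted as copied from trl) zeroes the dropout probability in place rather than swapping modules out, so the module tree and state dict are untouched; the usual motivation in RL fine-tuning is to keep forward passes deterministic so recomputed log-probabilities line up with the ones used at sampling time. A quick sanity check:

import torch

drop = torch.nn.Dropout(p=0.1)
disable_dropout_in_model(drop)   # modules() includes the module itself
x = torch.randn(4, 8)
assert torch.equal(drop(x), x)   # with p == 0, dropout is the identity even in train mode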
@@ -28,6 +28,7 @@ from arealite.api.workflow_api import RolloutWorkflow
 from arealite.engine.ppo.actor import FSDPPPOActor
 from arealite.engine.sglang_remote import RemoteSGLangEngine
 from arealite.utils.data import concat_padded_tensors
+from arealite.utils.device import log_gpu_stats
 from arealite.utils.evaluator import Evaluator
 from arealite.utils.saver import Saver
 from arealite.utils.stats_logger import StatsLogger
@@ -187,25 +188,23 @@ def main_grpo():
                 batch["logprobs"] = logp
             else:
                 batch["prox_logp"] = logp
+                log_gpu_stats("Recompute logp")

         if ref is not None:
             with stats_tracker.record_timing("ref_logp"):
                 batch["ref_logp"] = ref.compute_logp(batch)
+                log_gpu_stats("Ref logp")

         with stats_tracker.record_timing("compute_advantage"):
             actor.compute_advantages(batch)

-        with stats_tracker.record_timing("clear_cache"):
-            gc.collect()
-            torch.cuda.empty_cache()
-            gc.collect()
-
         with (
             stats_tracker.record_timing("train_step"),
             stats_tracker.scope("grpo_actor"),
         ):
             stats = actor.ppo_update(batch)
             actor.step_lr_scheduler()
+            log_gpu_stats("PPO update")

         with stats_tracker.record_timing("update_weights"):
             meta = WeightUpdateMeta(
@@ -10,7 +10,6 @@ cluster:
  name_resolve:
    type: etcd3
    etcd3_addr: etcd-client.openpsi-etcd.svc.sigma-na130-lingbo.na130.wl-robby.local:2379
- exclude: slurmd-69
  gpu_image: /storage/openpsi/images/areal-v0.3.0.post1.sif
  gpu_infer_image: /storage/openpsi/images/areal-v0.3.0.post1.sif
  seed: 1
@@ -30,7 +29,7 @@ rollout:
  gconfig:
    n_samples: 16
    min_new_tokens: 0
-   max_new_tokens: 15360
+   max_new_tokens: 30720
    greedy: false
    temperature: 1.0

@@ -39,10 +38,11 @@ actor:
  trial_name: ${trial_name}
  path: /storage/openpsi/models/deepseek-ai__DeepSeek-R1-Distill-Qwen-1.5B/
  init_from_scratch: false
+ disable_dropout: true
  gradient_checkpointing: true
  dtype: bfloat16
  mb_spec:
-   max_tokens_per_mb: 16384
+   max_tokens_per_mb: 32768
  optimizer:
    type: adam
    lr: 1e-6
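The new values line up with the commit title: max_new_tokens rises from 15360 to 30720, and with roughly 2k tokens of prompt budget a full-length sequence stays within 30720 + 2048 = 32768 tokens, which is exactly the new max_tokens_per_mb. In other words, each packed microbatch is sized so that even one worst-case 32k sequence fits on its own, while shorter sequences can still share a microbatch (assuming the dataset's prompts indeed fit in the remaining ~2k tokens).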
@@ -63,8 +63,8 @@ actor:
  reward_bias: -0.5
  kl_ctl: 0.0
  ppo_n_minibatches: 1
- recompute_logprob: true
- use_decoupled_loss: true
+ recompute_logprob: false
+ use_decoupled_loss: false
  behav_imp_weight_cap: 5.0

ref:
@@ -72,9 +72,10 @@ ref:
  trial_name: ${trial_name}
  path: ${actor.path}
  init_from_scratch: false
+ disable_dropout: true
  dtype: ${actor.dtype}
  mb_spec:
-   max_tokens_per_mb: 16384
+   max_tokens_per_mb: 32768
  optimizer: null
  backend: fsdp

@@ -91,7 +92,7 @@ sglang:

# datasets
train_dataset:
- batch_size: 128
+ batch_size: 16
  shuffle: true
  pin_memory: true

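Shrinking batch_size from 128 to 16 offsets the doubled generation length: with gconfig.n_samples: 16, each training step now collects 16 * 16 = 256 rollout sequences instead of 128 * 16 = 2048 (assuming batch_size counts prompts, each sampled n_samples times).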