mirror of https://github.com/inclusionAI/AReaL

ready for boba

commit 34b6941a2f
parent 43b3c3f8d0
@@ -756,7 +756,7 @@ def parse_cli_args(argv: List[str]):
     relpath = Path(os.path.relpath(str(config_file), Path(__file__).parent.absolute()))
     hydra_init(config_path=str(relpath.parent), job_name="app", version_base=None)
     cfg = hydra_compose(
-        config_name=str(relpath.name).rstrip(".yaml"),
+        config_name=str(relpath.name).split(".yaml")[0],
        overrides=overrides,
     )
     return cfg, config_file
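Why the rstrip call had to go: str.rstrip(".yaml") strips any trailing run of the characters '.', 'y', 'a', 'm', 'l' rather than the literal suffix, so some config names get silently mangled. A quick illustration with a hypothetical file name:

>>> "llama.yaml".rstrip(".yaml")
''
>>> "llama.yaml".split(".yaml")[0]
'llama'

split(".yaml")[0] keeps everything before the first occurrence of the suffix, which is what Hydra's config_name expects.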
@@ -95,7 +95,6 @@ class RemoteSGLangEngine(InferenceEngine):
             )
             return response.status_code == 200
         except requests.exceptions.RequestException as e:
             print(f"Check {base_url}/metrics failed, reason: {e}")
             return False
-
     def initialize(self, addr: str | None, ft_spec: FinetuneSpec = None):
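For context, the surrounding code is the server health probe that gates initialization: it polls the SGLang server's /metrics endpoint and treats any transport error as "not ready". A minimal self-contained sketch of the same pattern (base_url and timeout are illustrative, not the engine's actual attributes):

import requests

def server_is_healthy(base_url: str, timeout: float = 5.0) -> bool:
    # Probe the metrics endpoint; a 200 means the server is up and serving.
    try:
        response = requests.get(f"{base_url}/metrics", timeout=timeout)
        return response.status_code == 200
    except requests.exceptions.RequestException as e:
        # Connection refused, timeout, etc. simply mean "not ready yet".
        print(f"Check {base_url}/metrics failed, reason: {e}")
        return False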
@@ -110,11 +110,11 @@ def pad_input(hidden_states, indices, batch, seqlen):


 def concat_padded_tensors(
-    tensor_dicts: List[Dict[str, torch.Tensor]], pad_value: float = 0.0
-) -> Dict[str, torch.Tensor]:
+    tensor_dicts: List[TensorDict], pad_value: float = 0.0
+) -> TensorDict:
     """Concatenate and pad tensors from multiple padded tensor dictionaries."""
     if not tensor_dicts:
-        return {}
+        return TensorDict()

     # Find max sequence length across all dictionaries
     lens = []
@@ -156,7 +156,7 @@ def concat_padded_tensors(
         result[key] = torch.cat(tensors_to_concat, dim=0)
     if "attention_mask" not in result:
         result["attention_mask"] = attn_mask
-    return result
+    return TensorDict(result, batch_size=[len(lens)])


 def to_device(data: Dict[str, torch.Tensor | Any], device) -> Dict[str, torch.Tensor]:
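The switch from plain dicts to TensorDict lets callers carry the batch size alongside the tensors. Assuming the function right-pads every entry to the longest sequence with pad_value and pads attention_mask with zeros (as the docstring and the attn_mask handling above suggest), usage looks like this sketch:

import torch
from tensordict import TensorDict
from arealite.utils.data import concat_padded_tensors

a = TensorDict(
    {"input_ids": torch.ones(1, 3, dtype=torch.long),
     "attention_mask": torch.ones(1, 3, dtype=torch.long)},
    batch_size=[1],
)
b = TensorDict(
    {"input_ids": torch.ones(1, 5, dtype=torch.long),
     "attention_mask": torch.ones(1, 5, dtype=torch.long)},
    batch_size=[1],
)
out = concat_padded_tensors([a, b])
# Both inputs are padded to length 5, then stacked along the batch dim.
assert out.batch_size == torch.Size([2])
assert out["input_ids"].shape == (2, 5)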
@@ -21,6 +21,18 @@ class Saver:
             freq_sec=config.freq_secs,
         )

+    @staticmethod
+    def get_save_checkpoint_root(
+        config: SaverConfig,
+        name: str = "default",
+    ):
+        path = os.path.join(
+            f"{config.fileroot}/checkpoints/{getpass.getuser()}/{config.experiment_name}/{config.trial_name}",
+            name,
+        )
+        os.makedirs(path, exist_ok=True)
+        return path
+
     @staticmethod
     def get_save_checkpoint_path(
         config: SaverConfig,
@@ -30,8 +42,7 @@ class Saver:
         name: str = "default",
     ):
         path = os.path.join(
-            f"{config.fileroot}/checkpoints/{getpass.getuser()}/{config.experiment_name}/{config.trial_name}",
-            name,
+            Saver.get_save_checkpoint_root(config, name),
             f"epoch{epoch}epochstep{step}globalstep{globalstep}",
         )
         os.makedirs(path, exist_ok=True)
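With the shared root factored out, checkpointing and the disk-based weight sync below resolve to the same directory tree. A sketch of the resulting layout under this commit's config (user name and step numbers are illustrative; the keyword names for get_save_checkpoint_path are assumed from the f-string above):

root = Saver.get_save_checkpoint_root(config.saver)
# /storage/openpsi/experiments/checkpoints/<user>/lite-boba-math/run1/default

path = Saver.get_save_checkpoint_path(config.saver, epoch=0, step=10, globalstep=10)
# <root>/epoch0epochstep10globalstep10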
@@ -0,0 +1,236 @@
+import asyncio
+import os
+import re
+import sys
+import uuid
+
+import torch
+import torch.distributed as dist
+from datasets import Dataset, load_dataset
+from datasets.distributed import split_dataset_by_node
+from tensordict import TensorDict
+from torchdata.stateful_dataloader import StatefulDataLoader
+from transformers import PreTrainedTokenizerFast
+
+from arealite.api.cli_args import (
+    GenerationHyperparameters,
+    GRPOConfig,
+    load_expr_config,
+)
+from arealite.api.io_struct import (
+    FinetuneSpec,
+    LLMRequest,
+    LLMResponse,
+    WeightUpdateMeta,
+)
+from arealite.api.workflow_api import RolloutWorkflow
+from arealite.engine.ppo.actor import FSDPPPOActor
+from arealite.engine.sglang_remote import RemoteSGLangEngine
+from arealite.utils.data import concat_padded_tensors
+from arealite.utils.evaluator import Evaluator
+from arealite.utils.saver import Saver
+from arealite.utils.stats_logger import StatsLogger
+from realhf.api.core.data_api import load_hf_tokenizer
+from realhf.base import stats_tracker
+
+
+class RLVRWorkflow(RolloutWorkflow):
+    def __init__(
+        self,
+        reward_fn,
+        gconfig: GenerationHyperparameters,
+        tokenizer: PreTrainedTokenizerFast,
+    ):
+        self.reward_fn = reward_fn
+        self.gconfig = gconfig
+        self.tokenizer = tokenizer
+
+    async def arun_episode(self, engine, data):
+        input_ids = self.tokenizer.encode(data["prompt"])
+        n_samples = self.gconfig.n_samples
+        req = LLMRequest(
+            rid=uuid.uuid4().hex,
+            input_ids=input_ids,
+            gconfig=self.gconfig.new(n_samples=1),
+        )
+        resps = await asyncio.gather(*[engine.agenerate(req) for _ in range(n_samples)])
+
+        results = []
+        for resp in resps:
+            seq = resp.input_tokens + resp.output_tokens
+            logprobs = [0] * resp.input_len + resp.output_logprobs
+            prompt_mask = [1] * resp.input_len + [0] * resp.output_len
+            versions = [-1] * resp.input_len + resp.output_versions
+
+            reward = self.reward_fn(
+                completions=self.tokenizer.decode(resp.output_tokens),
+                prompt_ids=resp.input_tokens,
+                completion_ids=resp.output_tokens,
+                **data,
+            )
+            res = dict(
+                # unsqueeze to add an additional batch dimension
+                input_ids=torch.tensor(seq).unsqueeze(0),
+                prompt_mask=torch.tensor(prompt_mask).unsqueeze(0),
+                logprobs=torch.tensor(logprobs).unsqueeze(0),
+                versions=torch.tensor(versions).unsqueeze(0),
+                attention_mask=torch.ones(len(seq)).unsqueeze(0),
+                # reward
+                rewards=torch.tensor([reward]),
+            )
+            results.append(TensorDict(res, batch_size=[1]))
+
+        return concat_padded_tensors(results)
+
+
+def get_boba_math_dataset(rank, world_size):
+    dataset = load_dataset(
+        path="json",
+        split="train",
+        data_files="/storage/openpsi/users/xushusheng.xss/training_data/boba_106k_0319.jsonl",
+    )
+    return split_dataset_by_node(dataset, rank=rank, world_size=world_size)
+
+
+def boba_reward_fn(
+    prompt, completions, prompt_ids, completion_ids, query_id, solutions, **kwargs
+):
+    from realhf.impl.dataset.math_parser import process_results
+
+    label = 0
+    for sol in solutions:
+        label = label or process_results(completions, sol)[0]
+    return label
+
+
+def main_grpo():
+    config, _ = load_expr_config(sys.argv[1:], GRPOConfig)
+    config: GRPOConfig
+
+    rank = int(os.getenv("RANK"))
+    world_size = int(os.getenv("WORLD_SIZE"))
+    tokenizer = load_hf_tokenizer(config.tokenizer_path)
+
+    # Create dataset and dataloaders
+    train_dataloader = StatefulDataLoader(
+        get_boba_math_dataset(rank, world_size),
+        batch_size=config.train_dataset.batch_size // world_size,
+        shuffle=config.train_dataset.shuffle,
+        num_workers=config.train_dataset.num_workers,
+        collate_fn=lambda x: x,
+        drop_last=config.train_dataset.drop_last,
+    )
+    ft_spec = FinetuneSpec(
+        total_train_epochs=config.total_train_epochs,
+        dataset_size=len(train_dataloader) * config.train_dataset.batch_size,
+        train_batch_size=config.train_dataset.batch_size,
+    )
+
+    # Initialize inference engine
+    rollout = RemoteSGLangEngine(config.rollout)
+    rollout.initialize(None, ft_spec)
+
+    # Initialize train engine
+    actor = FSDPPPOActor(config=config.actor)
+    actor.initialize(None, ft_spec)
+    ref = None
+    if config.actor.kl_ctl > 0 and config.ref is not None:
+        ref = FSDPPPOActor(config=config.ref)
+        ref.initialize(None, ft_spec)
+
+    # Create rollout workflow
+    if tokenizer.pad_token_id not in config.gconfig.stop_token_ids:
+        config.gconfig.stop_token_ids.append(tokenizer.pad_token_id)
+    if tokenizer.eos_token_id not in config.gconfig.stop_token_ids:
+        config.gconfig.stop_token_ids.append(tokenizer.eos_token_id)
+    workflow = RLVRWorkflow(
+        reward_fn=boba_reward_fn, gconfig=config.gconfig, tokenizer=tokenizer
+    )
+
+    # Run training.
+    saver = Saver(config.saver, ft_spec, for_recover=False)
+    logger = StatsLogger(config.stats_logger, ft_spec)
+    evaluator = Evaluator(config.evaluator, ft_spec)
+
+    total_epochs = config.total_train_epochs
+    steps_per_epoch = len(train_dataloader)
+    max_steps = total_epochs * steps_per_epoch
+
+    logger.info(f"total_epochs={total_epochs} step_per_epoch={steps_per_epoch}")
+    data_generator = iter(train_dataloader)
+    for global_step in range(max_steps):
+        epoch = global_step // steps_per_epoch
+        step = global_step % steps_per_epoch
+
+        with stats_tracker.record_timing("rollout"):
+            if config.async_training:
+                batch = rollout.prepare_batch(
+                    data_generator,
+                    train_dataloader,
+                    workflow=workflow,
+                )
+            else:
+                try:
+                    data = next(data_generator)
+                except StopIteration:
+                    data_generator = iter(train_dataloader)
+                    data = next(data_generator)
+                batch = rollout.rollout(data, workflow=workflow)
+
+        batch = batch.to(actor.device)
+
+        if config.actor.recompute_logprob:
+            with stats_tracker.record_timing("recompute_logp"):
+                logp = actor.compute_logp(batch)
+                if not config.actor.use_decoupled_loss:
+                    batch["logprobs"] = logp
+                else:
+                    batch["prox_logp"] = logp
+
+        if ref is not None:
+            with stats_tracker.record_timing("ref_logp"):
+                batch["ref_logp"] = ref.compute_logp(batch)
+
+        with stats_tracker.record_timing("compute_advantage"):
+            actor.compute_advantages(batch)
+
+        with (
+            stats_tracker.record_timing("train_step"),
+            stats_tracker.scope("grpo_actor"),
+        ):
+            stats = actor.ppo_update(batch)
+            actor.step_lr_scheduler()
+
+        with stats_tracker.record_timing("update_weights"):
+            meta = WeightUpdateMeta(
+                type="disk",
+                path=os.path.join(
+                    Saver.get_save_checkpoint_root(config.saver), "update_weights"
+                ),
+                alloc_mode=None,
+                comm_backend=None,
+                model_version=global_step + 1,
+            )
+            if dist.get_rank() == 0:
+                future = rollout.update_weights(meta)
+            actor.upload_weights(meta)
+            if dist.get_rank() == 0:
+                future.result()
+            rollout.set_version(global_step + 1)
+            dist.barrier()
+
+        with stats_tracker.record_timing("save"):
+            saver.save(actor, epoch, step, global_step)
+
+        logger.commit(epoch, step, global_step, stats)
+
+    actor.destroy()
+    if ref is not None:
+        ref.destroy()
+    rollout.destroy()
+    logger.close()
+    dist.destroy_process_group()
+
+
+if __name__ == "__main__":
+    main_grpo()
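The workflow above fires n_samples independent generations per prompt and tags every sequence with per-token versions and a prompt mask; the actor's compute_advantages then turns the scalar rewards into GRPO advantages. The exact computation lives inside FSDPPPOActor, but the core group-baseline idea (with group_size wired to gconfig.n_samples in the config below) is roughly this sketch, with illustrative names:

import torch

def group_baseline_advantages(rewards: torch.Tensor, group_size: int) -> torch.Tensor:
    # rewards: [batch * group_size]; consecutive entries share one prompt.
    grouped = rewards.view(-1, group_size)
    # Each sample is judged against the mean reward of its own group,
    # so no learned critic is needed.
    return (grouped - grouped.mean(dim=1, keepdim=True)).view(-1)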
@@ -0,0 +1,134 @@
+experiment_name: lite-boba-math
+trial_name: run1
+allocation_mode: sglang.d96p1t1+d32p1t1
+n_nodes: 16
+n_gpus_per_node: 8
+cluster:
+  cluster_name: na132
+  fileroot: /storage/openpsi/experiments
+  mount: /storage:/storage
+  name_resolve:
+    type: etcd3
+    etcd3_addr: etcd-client.openpsi-etcd.svc.sigma-na130-lingbo.na130.wl-robby.local:2379
+  gpu_image: /storage/openpsi/images/areal-v0.3.0.post1.sif
+  gpu_infer_image: /storage/openpsi/images/areal-v0.3.0.post1.sif
+seed: 1
+total_train_epochs: 10
+tokenizer_path: ${actor.path}
+async_training: true
+
+rollout:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  max_concurrent_rollouts: null
+  queue_size: null
+  consumer_batch_size: ${train_dataset.batch_size}
+  max_head_offpolicyness: 4
+  enable_rollout_tracing: false
+
+gconfig:
+  n_samples: 16
+  min_new_tokens: 0
+  max_new_tokens: 27648
+  greedy: false
+  temperature: 1.0
+
+actor:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  path: /storage/openpsi/models/deepseek-ai__DeepSeek-R1-Distill-Qwen-1.5B/
+  init_from_scratch: false
+  gradient_checkpointing: true
+  dtype: bfloat16
+  mb_spec:
+    max_tokens_per_mb: 32768
+  optimizer:
+    type: adam
+    lr: 1e-6
+    weight_decay: 0.01
+    beta1: 0.9
+    beta2: 0.999
+    eps: 1e-8
+    lr_scheduler_type: constant
+    gradient_clipping: 1.0
+    warmup_steps_proportion: 0.001
+  backend: fsdp
+
+  group_size: ${gconfig.n_samples}
+  group_adv_norm: false
+  eps_clip: 0.4
+  temperature: ${gconfig.temperature}
+  reward_scaling: 10.0
+  reward_bias: -0.5
+  kl_ctl: 0.0
+  ppo_n_minibatches: 1
+  recompute_logprob: true
+  use_decoupled_loss: true
+  behav_imp_weight_cap: 5.0
+
+ref:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  path: ${actor.path}
+  init_from_scratch: false
+  dtype: ${actor.dtype}
+  mb_spec:
+    max_tokens_per_mb: 32768
+  optimizer: null
+  backend: fsdp
+
+# SGLang
+server_only: false
+sglang:
+  model_path: ${actor.path}
+  random_seed: ${seed}
+  skip_tokenizer_init: true
+  dtype: ${actor.dtype}
+  max_running_requests: null
+  context_length: 32768
+  mem_fraction_static: 0.9
+
+# datasets
+train_dataset:
+  batch_size: 128
+  shuffle: true
+  pin_memory: true
+
+# Utilities
+saver:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  freq_epochs: 1
+  freq_steps: null
+  freq_secs: null
+
+checkpointer:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  freq_epochs: 1
+  freq_steps: null
+  freq_secs: 3600
+
+evaluator:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  freq_epochs: 1
+  freq_steps: null
+  freq_secs: null
+
+stats_logger:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  wandb:
+    mode: disabled
+
+# Launcher
+launcher:
+  inference_server_cpus_per_gpu: 15
+  inference_server_mem_per_gpu: 153600
+  trainer_cpus_per_gpu: 15
+  trainer_mem_per_gpu: 153600
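A note on the allocation string: assuming AReaL's d/p/t convention (data-, pipeline-, and tensor-parallel degrees), sglang.d96p1t1+d32p1t1 dedicates 96 GPUs to SGLang inference and 32 to FSDP training, which exactly matches the 16 x 8 GPU budget declared above:

inference_gpus = 96 * 1 * 1  # d96 p1 t1
trainer_gpus = 32 * 1 * 1    # d32 p1 t1
assert inference_gpus + trainer_gpus == 16 * 8  # n_nodes * n_gpus_per_node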
@@ -191,7 +191,9 @@ def main_grpo():
         with stats_tracker.record_timing("update_weights"):
             meta = WeightUpdateMeta(
                 type="disk",
-                path=os.path.join(config.cluster.fileroot, "update_weights"),
+                path=os.path.join(
+                    Saver.get_save_checkpoint_root(config.saver), "update_weights"
+                ),
                 alloc_mode=None,
                 comm_backend=None,
                 model_version=global_step + 1,
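This aligns the older example with the new script: both now publish weights under Saver.get_save_checkpoint_root, and both advance the model version to global_step + 1 so rollout workers can enforce max_head_offpolicyness. A hedged sketch of that staleness gate (the exact rule is internal to RemoteSGLangEngine; this only illustrates the bound):

def within_staleness_bound(sample_version: int, current_version: int,
                           max_head_offpolicyness: int = 4) -> bool:
    # A sample generated under older weights may still be trained on
    # if it lags the current weights by at most the configured bound.
    return current_version - sample_version <= max_head_offpolicyness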