ready for boba

bowei.fw 2025-07-12 19:20:02 +08:00
parent 43b3c3f8d0
commit 34b6941a2f
7 changed files with 391 additions and 9 deletions

View File

@@ -756,7 +756,7 @@ def parse_cli_args(argv: List[str]):
     relpath = Path(os.path.relpath(str(config_file), Path(__file__).parent.absolute()))
     hydra_init(config_path=str(relpath.parent), job_name="app", version_base=None)
     cfg = hydra_compose(
-        config_name=str(relpath.name).rstrip(".yaml"),
+        config_name=str(relpath.name).split(".yaml")[0],
         overrides=overrides,
     )
     return cfg, config_file
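
Note on this fix: str.rstrip(".yaml") strips any trailing run of the characters '.', 'y', 'a', 'm', 'l' rather than the literal ".yaml" suffix, so config names ending in those letters were silently truncated; split(".yaml")[0] keeps the stem intact. A minimal illustration (the file name is only an example):

    "boba.yaml".rstrip(".yaml")    # -> "bob"  (strips trailing chars from the set {., y, a, m, l})
    "boba.yaml".split(".yaml")[0]  # -> "boba" (splits on the literal suffix)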

View File

@@ -95,7 +95,6 @@ class RemoteSGLangEngine(InferenceEngine):
             )
             return response.status_code == 200
         except requests.exceptions.RequestException as e:
-            print(f"Check {base_url}/metrics failed, reason: {e}")
             return False

    def initialize(self, addr: str | None, ft_spec: FinetuneSpec = None):

View File

@@ -110,11 +110,11 @@ def pad_input(hidden_states, indices, batch, seqlen):
 def concat_padded_tensors(
-    tensor_dicts: List[Dict[str, torch.Tensor]], pad_value: float = 0.0
-) -> Dict[str, torch.Tensor]:
+    tensor_dicts: List[TensorDict], pad_value: float = 0.0
+) -> TensorDict:
     """Concatenate and pad tensors from multiple padded tensor dictionaries."""
     if not tensor_dicts:
-        return {}
+        return TensorDict()
     # Find max sequence length across all dictionaries
     lens = []
@@ -156,7 +156,7 @@ def concat_padded_tensors(
         result[key] = torch.cat(tensors_to_concat, dim=0)
     if "attention_mask" not in result:
         result["attention_mask"] = attn_mask
-    return result
+    return TensorDict(result, batch_size=[len(lens)])


 def to_device(data: Dict[str, torch.Tensor | Any], device) -> Dict[str, torch.Tensor]:
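
A minimal sketch of how the TensorDict-returning version behaves, assuming right-padding to the longest sequence as implemented above (shapes are illustrative, not taken from a test in this commit):

    import torch
    from tensordict import TensorDict
    from arealite.utils.data import concat_padded_tensors

    # Two single-row padded batches with different sequence lengths.
    a = TensorDict({"input_ids": torch.ones(1, 3, dtype=torch.long),
                    "attention_mask": torch.ones(1, 3, dtype=torch.long)}, batch_size=[1])
    b = TensorDict({"input_ids": torch.ones(1, 5, dtype=torch.long),
                    "attention_mask": torch.ones(1, 5, dtype=torch.long)}, batch_size=[1])

    merged = concat_padded_tensors([a, b])
    assert merged.batch_size == torch.Size([2])  # batch_size=[len(lens)]
    assert merged["input_ids"].shape == (2, 5)   # padded to the longest sequence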

View File

@@ -21,6 +21,18 @@ class Saver:
             freq_sec=config.freq_secs,
         )

+    @staticmethod
+    def get_save_checkpoint_root(
+        config: SaverConfig,
+        name: str = "default",
+    ):
+        path = os.path.join(
+            f"{config.fileroot}/checkpoints/{getpass.getuser()}/{config.experiment_name}/{config.trial_name}",
+            name,
+        )
+        os.makedirs(path, exist_ok=True)
+        return path
+
     @staticmethod
     def get_save_checkpoint_path(
         config: SaverConfig,
@@ -30,8 +42,7 @@ class Saver:
         name: str = "default",
     ):
         path = os.path.join(
-            f"{config.fileroot}/checkpoints/{getpass.getuser()}/{config.experiment_name}/{config.trial_name}",
-            name,
+            Saver.get_save_checkpoint_root(config, name),
             f"epoch{epoch}epochstep{step}globalstep{globalstep}",
         )
         os.makedirs(path, exist_ok=True)
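
For orientation, the directory layout this refactor produces, using the fileroot/experiment/trial values from the boba config below; the user name is illustrative and the epoch/step/globalstep argument names are only inferred from the f-string above:

    Saver.get_save_checkpoint_root(config.saver)
    # -> /storage/openpsi/experiments/checkpoints/alice/lite-boba-math/run1/default
    Saver.get_save_checkpoint_path(config.saver, epoch=0, step=1, globalstep=1)
    # -> .../run1/default/epoch0epochstep1globalstep1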

examples/arealite/boba.py (new file, 236 lines)
View File

@@ -0,0 +1,236 @@
import asyncio
import os
import re
import sys
import uuid

import torch
import torch.distributed as dist
from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node
from tensordict import TensorDict
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import PreTrainedTokenizerFast

from arealite.api.cli_args import (
    GenerationHyperparameters,
    GRPOConfig,
    load_expr_config,
)
from arealite.api.io_struct import (
    FinetuneSpec,
    LLMRequest,
    LLMResponse,
    WeightUpdateMeta,
)
from arealite.api.workflow_api import RolloutWorkflow
from arealite.engine.ppo.actor import FSDPPPOActor
from arealite.engine.sglang_remote import RemoteSGLangEngine
from arealite.utils.data import concat_padded_tensors
from arealite.utils.evaluator import Evaluator
from arealite.utils.saver import Saver
from arealite.utils.stats_logger import StatsLogger
from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import stats_tracker

class RLVRWorkflow(RolloutWorkflow):
    def __init__(
        self,
        reward_fn,
        gconfig: GenerationHyperparameters,
        tokenizer: PreTrainedTokenizerFast,
    ):
        self.reward_fn = reward_fn
        self.gconfig = gconfig
        self.tokenizer = tokenizer

    async def arun_episode(self, engine, data):
        input_ids = self.tokenizer.encode(data["prompt"])
        n_samples = self.gconfig.n_samples
        req = LLMRequest(
            rid=uuid.uuid4().hex,
            input_ids=input_ids,
            gconfig=self.gconfig.new(n_samples=1),
        )
        resps = await asyncio.gather(*[engine.agenerate(req) for _ in range(n_samples)])

        results = []
        for resp in resps:
            seq = resp.input_tokens + resp.output_tokens
            logprobs = [0] * resp.input_len + resp.output_logprobs
            prompt_mask = [1] * resp.input_len + [0] * resp.output_len
            versions = [-1] * resp.input_len + resp.output_versions
            reward = self.reward_fn(
                completions=self.tokenizer.decode(resp.output_tokens),
                prompt_ids=resp.input_tokens,
                completion_ids=resp.output_tokens,
                **data,
            )
            res = dict(
                # unsqueeze to add an additional batch dimension
                input_ids=torch.tensor(seq).unsqueeze(0),
                prompt_mask=torch.tensor(prompt_mask).unsqueeze(0),
                logprobs=torch.tensor(logprobs).unsqueeze(0),
                versions=torch.tensor(versions).unsqueeze(0),
                attention_mask=torch.ones(len(seq)).unsqueeze(0),
                # reward
                rewards=torch.tensor([reward]),
            )
            results.append(TensorDict(res, batch_size=[1]))
        return concat_padded_tensors(results)
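
For a prompt sampled n_samples times, the TensorDict returned above contains (shapes follow from concat_padded_tensors; L is the longest prompt+completion length in the group):

    # input_ids / prompt_mask / logprobs / versions / attention_mask: [n_samples, L]
    # rewards: [n_samples]
    # prompt_mask marks prompt positions, logprobs are 0 on the prompt, and
    # versions hold the generating model version per output token (-1 on the prompt).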

def get_boba_math_dataset(rank, world_size):
    dataset = load_dataset(
        path="json",
        split="train",
        data_files="/storage/openpsi/users/xushusheng.xss/training_data/boba_106k_0319.jsonl",
    )
    return split_dataset_by_node(dataset, rank=rank, world_size=world_size)
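
The record layout of boba_106k_0319.jsonl is not shown in this commit; judging from the keyword arguments the reward function receives via **data, each line plausibly looks like the following (field contents assumed):

    # {"prompt": "...", "query_id": "...", "solutions": ["...", "..."]}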

def boba_reward_fn(
    prompt, completions, prompt_ids, completion_ids, query_id, solutions, **kwargs
):
    from realhf.impl.dataset.math_parser import process_results

    label = 0
    for sol in solutions:
        label = label or process_results(completions, sol)[0]
    return label
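
In other words, the reward is 1 when any reference solution verifies the completion and 0 otherwise; an equivalent sketch (up to bool/int conversion):

    # return int(any(process_results(completions, sol)[0] for sol in solutions))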

def main_grpo():
    config, _ = load_expr_config(sys.argv[1:], GRPOConfig)
    config: GRPOConfig

    rank = int(os.getenv("RANK"))
    world_size = int(os.getenv("WORLD_SIZE"))
    tokenizer = load_hf_tokenizer(config.tokenizer_path)

    # Create dataset and dataloaders
    train_dataloader = StatefulDataLoader(
        get_boba_math_dataset(rank, world_size),
        batch_size=config.train_dataset.batch_size // world_size,
        shuffle=config.train_dataset.shuffle,
        num_workers=config.train_dataset.num_workers,
        collate_fn=lambda x: x,
        drop_last=config.train_dataset.drop_last,
    )
    ft_spec = FinetuneSpec(
        total_train_epochs=config.total_train_epochs,
        dataset_size=len(train_dataloader) * config.train_dataset.batch_size,
        train_batch_size=config.train_dataset.batch_size,
    )

    # Initialize inference engine
    rollout = RemoteSGLangEngine(config.rollout)
    rollout.initialize(None, ft_spec)

    # Initialize train engine
    actor = FSDPPPOActor(config=config.actor)
    actor.initialize(None, ft_spec)
    ref = None
    if config.actor.kl_ctl > 0 and config.ref is not None:
        ref = FSDPPPOActor(config=config.ref)
        ref.initialize(None, ft_spec)

    # Create rollout workflow
    if tokenizer.pad_token_id not in config.gconfig.stop_token_ids:
        config.gconfig.stop_token_ids.append(tokenizer.pad_token_id)
    if tokenizer.eos_token_id not in config.gconfig.stop_token_ids:
        config.gconfig.stop_token_ids.append(tokenizer.eos_token_id)
    workflow = RLVRWorkflow(
        reward_fn=boba_reward_fn, gconfig=config.gconfig, tokenizer=tokenizer
    )

    # Run training.
    saver = Saver(config.saver, ft_spec, for_recover=False)
    logger = StatsLogger(config.stats_logger, ft_spec)
    evaluator = Evaluator(config.evaluator, ft_spec)

    total_epochs = config.total_train_epochs
    steps_per_epoch = len(train_dataloader)
    max_steps = total_epochs * steps_per_epoch
    logger.info(f"total_epochs={total_epochs} step_per_epoch={steps_per_epoch}")

    data_generator = iter(train_dataloader)
    for global_step in range(max_steps):
        epoch = global_step // steps_per_epoch
        step = global_step % steps_per_epoch

        with stats_tracker.record_timing("rollout"):
            if config.async_training:
                batch = rollout.prepare_batch(
                    data_generator,
                    train_dataloader,
                    workflow=workflow,
                )
            else:
                try:
                    data = next(data_generator)
                except StopIteration:
                    data_generator = iter(train_dataloader)
                    data = next(data_generator)
                batch = rollout.rollout(data, workflow=workflow)
        batch = batch.to(actor.device)

        if config.actor.recompute_logprob:
            with stats_tracker.record_timing("recompute_logp"):
                logp = actor.compute_logp(batch)
            if not config.actor.use_decoupled_loss:
                batch["logprobs"] = logp
            else:
                batch["prox_logp"] = logp

        if ref is not None:
            with stats_tracker.record_timing("ref_logp"):
                batch["ref_logp"] = ref.compute_logp(batch)

        with stats_tracker.record_timing("compute_advantage"):
            actor.compute_advantages(batch)

        with (
            stats_tracker.record_timing("train_step"),
            stats_tracker.scope("grpo_actor"),
        ):
            stats = actor.ppo_update(batch)
            actor.step_lr_scheduler()

        with stats_tracker.record_timing("update_weights"):
            meta = WeightUpdateMeta(
                type="disk",
                path=os.path.join(
                    Saver.get_save_checkpoint_root(config.saver), "update_weights"
                ),
                alloc_mode=None,
                comm_backend=None,
                model_version=global_step + 1,
            )
            if dist.get_rank() == 0:
                future = rollout.update_weights(meta)
            actor.upload_weights(meta)
            if dist.get_rank() == 0:
                future.result()
            rollout.set_version(global_step + 1)
            dist.barrier()

        with stats_tracker.record_timing("save"):
            saver.save(actor, epoch, step, global_step)

        logger.commit(epoch, step, global_step, stats)

    actor.destroy()
    if ref is not None:
        ref.destroy()
    rollout.destroy()
    logger.close()
    dist.destroy_process_group()


if __name__ == "__main__":
    main_grpo()

View File

@@ -0,0 +1,134 @@
experiment_name: lite-boba-math
trial_name: run1

allocation_mode: sglang.d96p1t1+d32p1t1
n_nodes: 16
n_gpus_per_node: 8

cluster:
  cluster_name: na132
  fileroot: /storage/openpsi/experiments
  mount: /storage:/storage
  name_resolve:
    type: etcd3
    etcd3_addr: etcd-client.openpsi-etcd.svc.sigma-na130-lingbo.na130.wl-robby.local:2379
  gpu_image: /storage/openpsi/images/areal-v0.3.0.post1.sif
  gpu_infer_image: /storage/openpsi/images/areal-v0.3.0.post1.sif

seed: 1
total_train_epochs: 10
tokenizer_path: ${actor.path}
async_training: true

rollout:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  max_concurrent_rollouts: null
  queue_size: null
  consumer_batch_size: ${train_dataset.batch_size}
  max_head_offpolicyness: 4
  enable_rollout_tracing: false

gconfig:
  n_samples: 16
  min_new_tokens: 0
  max_new_tokens: 27648
  greedy: false
  temperature: 1.0

actor:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  path: /storage/openpsi/models/deepseek-ai__DeepSeek-R1-Distill-Qwen-1.5B/
  init_from_scratch: false
  gradient_checkpointing: true
  dtype: bfloat16
  mb_spec:
    max_tokens_per_mb: 32768
  optimizer:
    type: adam
    lr: 1e-6
    weight_decay: 0.01
    beta1: 0.9
    beta2: 0.999
    eps: 1e-8
    lr_scheduler_type: constant
    gradient_clipping: 1.0
    warmup_steps_proportion: 0.001
  backend: fsdp
  group_size: ${gconfig.n_samples}
  group_adv_norm: false
  eps_clip: 0.4
  temperature: ${gconfig.temperature}
  reward_scaling: 10.0
  reward_bias: -0.5
  kl_ctl: 0.0
  ppo_n_minibatches: 1
  recompute_logprob: true
  use_decoupled_loss: true
  behav_imp_weight_cap: 5.0

ref:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  path: ${actor.path}
  init_from_scratch: false
  dtype: ${actor.dtype}
  mb_spec:
    max_tokens_per_mb: 32768
  optimizer: null
  backend: fsdp

# SGLang
server_only: false
sglang:
  model_path: ${actor.path}
  random_seed: ${seed}
  skip_tokenizer_init: true
  dtype: ${actor.dtype}
  max_running_requests: null
  context_length: 32768
  mem_fraction_static: 0.9

# datasets
train_dataset:
  batch_size: 128
  shuffle: true
  pin_memory: true

# Utilities
saver:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  freq_epochs: 1
  freq_steps: null
  freq_secs: null

checkpointer:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  freq_epochs: 1
  freq_steps: null
  freq_secs: 3600

evaluator:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  freq_epochs: 1
  freq_steps: null
  freq_secs: null

stats_logger:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  wandb:
    mode: disabled

# Launcher
launcher:
  inference_server_cpus_per_gpu: 15
  inference_server_mem_per_gpu: 153600
  trainer_cpus_per_gpu: 15
  trainer_mem_per_gpu: 153600
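
The ${...} references are standard OmegaConf interpolations, resolved when the config is composed via hydra/load_expr_config. A quick way to inspect the resolved values; the file name is assumed, since the commit view does not show it:

    from omegaconf import OmegaConf

    cfg = OmegaConf.load("examples/arealite/boba.yaml")  # path assumed
    resolved = OmegaConf.to_container(cfg, resolve=True)
    print(resolved["tokenizer_path"])        # follows ${actor.path}
    print(resolved["actor"]["group_size"])   # 16, via ${gconfig.n_samples}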

View File

@@ -191,7 +191,9 @@ def main_grpo():
         with stats_tracker.record_timing("update_weights"):
             meta = WeightUpdateMeta(
                 type="disk",
-                path=os.path.join(config.cluster.fileroot, "update_weights"),
+                path=os.path.join(
+                    Saver.get_save_checkpoint_root(config.saver), "update_weights"
+                ),
                 alloc_mode=None,
                 comm_backend=None,
                 model_version=global_step + 1,