0724_merge3

This commit is contained in:
朱晗 2025-07-24 15:22:12 +08:00
parent 176ec4bb23
commit 5118cfaea2
7 changed files with 46 additions and 42 deletions

View File

@ -7,7 +7,7 @@ import torch.distributed as dist
import wandb
from torchdata.stateful_dataloader import StatefulDataLoader
from AReaL.arealite.workflow.Visionrlvr import VisionRLVRWorkflow
from arealite.workflow.vision_rlvr import VisionRLVRWorkflow
from arealite.api.cli_args import GRPOConfig, load_expr_config
from arealite.api.io_struct import FinetuneSpec, WeightUpdateMeta
from arealite.dataset.__init__ import get_custom_dataset
@ -69,7 +69,7 @@ def main(args):
rank=rank,
world_size=world_size,
split="train",
training_type=config.train_dataset.type,
type=config.train_dataset.type,
processor=processor,
)
valid_dataset = get_custom_dataset(
@ -77,7 +77,7 @@ def main(args):
rank=rank,
world_size=world_size,
split="test",
training_type=config.valid_dataset.type,
type=config.valid_dataset.type,
processor=processor,
)
# Create dataset and dataloaders

View File

@ -27,7 +27,7 @@ def main_sft():
rank=rank,
world_size=world_size,
split="train",
training_type=config.train_dataset.type,
type=config.train_dataset.type,
tokenizer=tokenizer,
processor=processor,
)
@ -36,7 +36,7 @@ def main_sft():
rank=rank,
world_size=world_size,
split="test",
training_type=config.valid_dataset.type,
type=config.valid_dataset.type,
tokenizer=tokenizer,
processor=processor,
)

View File

@ -1,10 +1,11 @@
experiment_name: gsm8k-grpo
trial_name: trial0
allocation_mode: sglang.d4p1t1+d4p1t1
n_nodes: 1
n_gpus_per_node: 8
cluster:
fileroot: /tmp/arealite/experiments
n_nodes: 1
n_gpus_per_node: 8
name_resolve:
type: nfs
nfs_record_root: /tmp/areal/name_resolve

View File

@ -10,6 +10,7 @@ from torchdata.stateful_dataloader import StatefulDataLoader
from arealite.api.cli_args import GRPOConfig, load_expr_config
from arealite.api.io_struct import FinetuneSpec, WeightUpdateMeta
from arealite.dataset.__init__ import get_custom_dataset
from arealite.engine.ppo.actor import FSDPPPOActor
from arealite.engine.sglang_remote import RemoteSGLangEngine
from arealite.utils.device import log_gpu_stats
@ -23,19 +24,6 @@ from realhf.base import logging, seeding, stats_tracker
logger = logging.getLogger("GSM8K grpo")
def process_gsm8k_rl_dataset(dataset: Dataset):
    """Turn each sample's question into a single-turn chat prompt.

    Each sample is mapped to a ``messages`` field containing exactly one
    message with role ``"user"`` whose content is the sample's ``question``.
    The ``question`` column is removed afterwards.

    Args:
        dataset: Object exposing ``map`` and ``remove_columns``
            (presumably a Hugging Face ``Dataset`` - confirm with caller).

    Returns:
        Whatever ``remove_columns`` returns after the mapping step.
    """

    def _to_chat(sample):
        # One user turn per sample; no system prompt is added here.
        user_turn = {"role": "user", "content": sample["question"]}
        return {"messages": [user_turn]}

    with_messages = dataset.map(_to_chat)
    return with_messages.remove_columns(["question"])
def get_gsm8k_dataset(split, rank, world_size):
    """Load a GSM8K split, split it by node, and format it for RL.

    Args:
        split: Split name forwarded to ``load_dataset``.
        rank: Forwarded to ``split_dataset_by_node`` as ``rank``.
        world_size: Forwarded to ``split_dataset_by_node`` as ``world_size``.

    Returns:
        The result of ``process_gsm8k_rl_dataset`` applied to this rank's
        portion of the ``openai/gsm8k`` (``main`` config) split.
    """
    full_split = load_dataset(path="openai/gsm8k", name="main", split=split)
    # Presumably each rank ends up with a disjoint shard - confirm against
    # the split_dataset_by_node documentation.
    shard = split_dataset_by_node(full_split, rank=rank, world_size=world_size)
    return process_gsm8k_rl_dataset(shard)
def gsm8k_reward_fn(prompt, completions, prompt_ids, completion_ids, answer, **kwargs):
@ -53,10 +41,26 @@ def main(args):
tokenizer = load_hf_tokenizer(config.tokenizer_path)
seeding.set_random_seed(config.seed, key=f"trainer{rank}")
train_dataset = get_custom_dataset(
path=config.train_dataset.path,
rank=rank,
world_size=world_size,
split="train",
type=config.train_dataset.type,
tokenizer=tokenizer,
)
valid_dataset = get_custom_dataset(
path=config.valid_dataset.path,
rank=rank,
world_size=world_size,
split="test",
type=config.valid_dataset.type,
tokenizer=tokenizer,
)
# Create dataset and dataloaders
train_dataloader = StatefulDataLoader(
get_gsm8k_dataset("train", rank, world_size),
train_dataset,
batch_size=config.train_dataset.batch_size // world_size,
shuffle=config.train_dataset.shuffle,
num_workers=config.train_dataset.num_workers,
@ -64,7 +68,7 @@ def main(args):
drop_last=config.train_dataset.drop_last,
)
valid_dataloader = StatefulDataLoader(
get_gsm8k_dataset("test", rank, world_size),
valid_dataset,
batch_size=config.valid_dataset.batch_size // world_size,
shuffle=config.valid_dataset.shuffle,
num_workers=config.valid_dataset.num_workers,

View File

@ -7,6 +7,7 @@ from torchdata.stateful_dataloader import StatefulDataLoader
from arealite.api.cli_args import SFTConfig, load_expr_config
from arealite.api.io_struct import FinetuneSpec
from arealite.dataset.__init__ import get_custom_dataset
from arealite.engine.sft.lm_engine import FSDPLMEngine
from arealite.utils.data import pad_sequences_to_tensors
from arealite.utils.evaluator import Evaluator
@ -16,25 +17,6 @@ from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import stats_tracker
def process_gsm8k_sft_dataset(dataset: Dataset, tokenizer):
    """Tokenize GSM8K samples for SFT and build a prompt/answer loss mask.

    For each sample the text ``question + answer + eos_token`` is encoded
    into ``input_ids``. ``loss_mask`` has one entry per token: ``0`` for the
    leading positions counted as prompt and ``1`` for the rest. The
    ``question`` and ``answer`` columns are removed afterwards.

    Args:
        dataset: Object exposing ``map`` and ``remove_columns``
            (presumably a Hugging Face ``Dataset`` - confirm with caller).
        tokenizer: Object exposing ``encode(text)`` returning a sequence of
            token ids and an ``eos_token`` string attribute.

    Returns:
        Whatever ``remove_columns`` returns after the mapping step.
    """

    def process(sample):
        question = sample["question"]
        seq_token = tokenizer.encode(
            question + sample["answer"] + tokenizer.eos_token
        )
        prompt_token = tokenizer.encode(question)
        # The prompt is encoded separately, so its length is not guaranteed
        # to be <= the full sequence length (e.g. automatically added special
        # tokens). Clamp it so the mask always matches input_ids in length.
        n_prompt = min(len(prompt_token), len(seq_token))
        loss_mask = [0] * n_prompt + [1] * (len(seq_token) - n_prompt)
        return {"input_ids": seq_token, "loss_mask": loss_mask}

    return dataset.map(process).remove_columns(["question", "answer"])
def get_gsm8k_dataset(split, tokenizer, rank, world_size):
    """Load a GSM8K split, split it by node, and tokenize it for SFT.

    Args:
        split: Split name forwarded to ``load_dataset``.
        tokenizer: Passed through to ``process_gsm8k_sft_dataset``.
        rank: Forwarded to ``split_dataset_by_node`` as ``rank``.
        world_size: Forwarded to ``split_dataset_by_node`` as ``world_size``.

    Returns:
        The result of ``process_gsm8k_sft_dataset`` applied to this rank's
        portion of the ``openai/gsm8k`` (``main`` config) split.
    """
    full_split = load_dataset(path="openai/gsm8k", name="main", split=split)
    # Presumably each rank ends up with a disjoint shard - confirm against
    # the split_dataset_by_node documentation.
    shard = split_dataset_by_node(full_split, rank=rank, world_size=world_size)
    return process_gsm8k_sft_dataset(shard, tokenizer)
def main(args):
config, _ = load_expr_config(args, SFTConfig)
config: SFTConfig
@ -42,6 +24,23 @@ def main(args):
rank = int(os.getenv("RANK"))
world_size = int(os.getenv("WORLD_SIZE"))
tokenizer = load_hf_tokenizer(config.tokenizer_path)
train_dataset = get_custom_dataset(
path=config.train_dataset.path,
rank=rank,
world_size=world_size,
split="train",
type=config.train_dataset.type,
tokenizer=tokenizer,
)
valid_dataset = get_custom_dataset(
path=config.valid_dataset.path,
rank=rank,
world_size=world_size,
split="test",
type=config.valid_dataset.type,
tokenizer=tokenizer,
)
# Create dataset and dataloaders
train_dataloader = StatefulDataLoader(

View File

@ -31,7 +31,7 @@ dependencies = [
"huggingface_hub",
"datasets",
"accelerate",
"transformers>=4.53.3",
"transformers>=4.53.1",
# Scientific computing
"numpy<2.0.0",