mirror of https://github.com/inclusionAI/AReaL

commit 5118cfaea2
parent 176ec4bb23

    0724_merge3

@@ -7,7 +7,7 @@ import torch.distributed as dist
 import wandb
 from torchdata.stateful_dataloader import StatefulDataLoader
 
-from AReaL.arealite.workflow.Visionrlvr import VisionRLVRWorkflow
+from arealite.workflow.vision_rlvr import VisionRLVRWorkflow
 from arealite.api.cli_args import GRPOConfig, load_expr_config
 from arealite.api.io_struct import FinetuneSpec, WeightUpdateMeta
 from arealite.dataset.__init__ import get_custom_dataset
@@ -69,7 +69,7 @@ def main(args):
         rank=rank,
         world_size=world_size,
         split="train",
-        training_type=config.train_dataset.type,
+        type=config.train_dataset.type,
         processor=processor,
     )
     valid_dataset = get_custom_dataset(
@@ -77,7 +77,7 @@ def main(args):
         rank=rank,
         world_size=world_size,
         split="test",
-        training_type=config.valid_dataset.type,
+        type=config.valid_dataset.type,
         processor=processor,
     )
     # Create dataset and dataloaders
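
The training_type= to type= rename in the two hunks above matches the keyword that get_custom_dataset dispatches on. A minimal sketch of such a type-keyed dispatcher, assuming a simple registry; the registry and decorator below are hypothetical, and the real implementation lives in arealite/dataset/__init__.py:

    # Hypothetical sketch of a type-keyed dataset dispatcher; the actual
    # arealite get_custom_dataset may be structured differently.
    from typing import Any, Callable, Dict

    _LOADERS: Dict[str, Callable[..., Any]] = {}

    def register_dataset(type: str):
        def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
            _LOADERS[type] = fn  # e.g. "sft" and "rl" need different processing
            return fn
        return decorator

    def get_custom_dataset(path: str, type: str, **kwargs) -> Any:
        # The same dataset path can back several training types.
        return _LOADERS[type](path=path, **kwargs)
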
@@ -27,7 +27,7 @@ def main_sft():
         rank=rank,
         world_size=world_size,
         split="train",
-        training_type=config.train_dataset.type,
+        type=config.train_dataset.type,
         tokenizer=tokenizer,
         processor=processor,
     )
@@ -36,7 +36,7 @@ def main_sft():
         rank=rank,
         world_size=world_size,
         split="test",
-        training_type=config.valid_dataset.type,
+        type=config.valid_dataset.type,
         tokenizer=tokenizer,
         processor=processor,
     )

@@ -1,10 +1,11 @@
 experiment_name: gsm8k-grpo
 trial_name: trial0
 allocation_mode: sglang.d4p1t1+d4p1t1
-n_nodes: 1
-n_gpus_per_node: 8
 
 cluster:
   fileroot: /tmp/arealite/experiments
+  n_nodes: 1
+  n_gpus_per_node: 8
+  name_resolve:
+    type: nfs
+    nfs_record_root: /tmp/areal/name_resolve
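
The config change moves the node topology under cluster and adds NFS-based name resolution. A quick sanity check of the resulting block, using plain PyYAML rather than AReaL's own config loader:

    # Parse the post-change cluster block with plain PyYAML (illustrative only;
    # AReaL loads experiment configs through load_expr_config).
    import yaml

    cfg = yaml.safe_load("""\
    cluster:
      fileroot: /tmp/arealite/experiments
      n_nodes: 1
      n_gpus_per_node: 8
      name_resolve:
        type: nfs
        nfs_record_root: /tmp/areal/name_resolve
    """)

    assert cfg["cluster"]["n_nodes"] * cfg["cluster"]["n_gpus_per_node"] == 8
    print(cfg["cluster"]["name_resolve"]["type"])  # -> nfs
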
@@ -10,6 +10,7 @@ from torchdata.stateful_dataloader import StatefulDataLoader
 
 from arealite.api.cli_args import GRPOConfig, load_expr_config
 from arealite.api.io_struct import FinetuneSpec, WeightUpdateMeta
+from arealite.dataset.__init__ import get_custom_dataset
 from arealite.engine.ppo.actor import FSDPPPOActor
 from arealite.engine.sglang_remote import RemoteSGLangEngine
 from arealite.utils.device import log_gpu_stats
@@ -23,19 +24,6 @@ from realhf.base import logging, seeding, stats_tracker
 logger = logging.getLogger("GSM8K grpo")
 
 
-def process_gsm8k_rl_dataset(dataset: Dataset):
-    def process(sample):
-        messages = [{"role": "user", "content": sample["question"]}]
-        return {"messages": messages}
-
-    dataset = dataset.map(process).remove_columns(["question"])
-    return dataset
-
-
-def get_gsm8k_dataset(split, rank, world_size):
-    dataset = load_dataset(path="openai/gsm8k", name="main", split=split)
-    dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
-    return process_gsm8k_rl_dataset(dataset)
-
-
 def gsm8k_reward_fn(prompt, completions, prompt_ids, completion_ids, answer, **kwargs):
@@ -53,10 +41,26 @@ def main(args):
     tokenizer = load_hf_tokenizer(config.tokenizer_path)
 
     seeding.set_random_seed(config.seed, key=f"trainer{rank}")
+    train_dataset = get_custom_dataset(
+        path=config.train_dataset.path,
+        rank=rank,
+        world_size=world_size,
+        split="train",
+        type=config.train_dataset.type,
+        tokenizer=tokenizer,
+    )
+    valid_dataset = get_custom_dataset(
+        path=config.valid_dataset.path,
+        rank=rank,
+        world_size=world_size,
+        split="test",
+        type=config.valid_dataset.type,
+        tokenizer=tokenizer,
+    )
 
     # Create dataset and dataloaders
     train_dataloader = StatefulDataLoader(
-        get_gsm8k_dataset("train", rank, world_size),
+        train_dataset,
         batch_size=config.train_dataset.batch_size // world_size,
         shuffle=config.train_dataset.shuffle,
         num_workers=config.train_dataset.num_workers,
@@ -64,7 +68,7 @@ def main(args):
         drop_last=config.train_dataset.drop_last,
     )
     valid_dataloader = StatefulDataLoader(
-        get_gsm8k_dataset("test", rank, world_size),
+        valid_dataset,
         batch_size=config.valid_dataset.batch_size // world_size,
         shuffle=config.valid_dataset.shuffle,
         num_workers=config.valid_dataset.num_workers,
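
The two inline helpers deleted above are exactly what the new get_custom_dataset(type=...) call replaces. For reference, the same gsm8k RL processing as it might look once registered as a custom dataset; the function name and signature here are hypothetical, while the body is copied from the removed code:

    from datasets import load_dataset
    from datasets.distributed import split_dataset_by_node

    def load_gsm8k_rl(path, split, rank, world_size, **kwargs):
        # Mirrors the removed get_gsm8k_dataset/process_gsm8k_rl_dataset.
        dataset = load_dataset(path=path, name="main", split=split)
        dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)

        def process(sample):
            # RL (RLVR) samples keep the question as a chat-format prompt.
            return {"messages": [{"role": "user", "content": sample["question"]}]}

        return dataset.map(process).remove_columns(["question"])

With config.train_dataset.path set to "openai/gsm8k" and type set to the matching registry key, the refactored main() above would reproduce the old behavior.
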
@@ -7,6 +7,7 @@ from torchdata.stateful_dataloader import StatefulDataLoader
 
 from arealite.api.cli_args import SFTConfig, load_expr_config
 from arealite.api.io_struct import FinetuneSpec
+from arealite.dataset.__init__ import get_custom_dataset
 from arealite.engine.sft.lm_engine import FSDPLMEngine
 from arealite.utils.data import pad_sequences_to_tensors
 from arealite.utils.evaluator import Evaluator
@@ -16,25 +17,6 @@ from realhf.api.core.data_api import load_hf_tokenizer
 from realhf.base import stats_tracker
 
 
-def process_gsm8k_sft_dataset(dataset: Dataset, tokenizer):
-    def process(sample):
-        seq_token = tokenizer.encode(
-            sample["question"] + sample["answer"] + tokenizer.eos_token
-        )
-        prompt_token = tokenizer.encode(sample["question"])
-        loss_mask = [0] * len(prompt_token) + [1] * (len(seq_token) - len(prompt_token))
-        return {"input_ids": seq_token, "loss_mask": loss_mask}
-
-    dataset = dataset.map(process).remove_columns(["question", "answer"])
-    return dataset
-
-
-def get_gsm8k_dataset(split, tokenizer, rank, world_size):
-    dataset = load_dataset(path="openai/gsm8k", name="main", split=split)
-    dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
-    return process_gsm8k_sft_dataset(dataset, tokenizer)
-
-
 def main(args):
     config, _ = load_expr_config(args, SFTConfig)
     config: SFTConfig
@@ -42,6 +24,23 @@ def main(args):
     rank = int(os.getenv("RANK"))
     world_size = int(os.getenv("WORLD_SIZE"))
     tokenizer = load_hf_tokenizer(config.tokenizer_path)
 
+    train_dataset = get_custom_dataset(
+        path=config.train_dataset.path,
+        rank=rank,
+        world_size=world_size,
+        split="train",
+        type=config.train_dataset.type,
+        tokenizer=tokenizer,
+    )
+    valid_dataset = get_custom_dataset(
+        path=config.valid_dataset.path,
+        rank=rank,
+        world_size=world_size,
+        split="test",
+        type=config.valid_dataset.type,
+        tokenizer=tokenizer,
+    )
+
     # Create dataset and dataloaders
     train_dataloader = StatefulDataLoader(
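
The deleted SFT preprocessing is the part a custom dataset type has to reproduce: prompt tokens are masked out of the loss so that only answer tokens contribute. A toy worked example of that mask (token ids are invented; no real tokenizer is involved):

    # loss_mask as built by the removed process_gsm8k_sft_dataset:
    # 0 over the prompt, 1 over the answer (and EOS). Ids are illustrative.
    prompt_token = [101, 7592, 2129]        # tokenizer.encode(question)
    seq_token = prompt_token + [2048, 102]  # encode(question + answer + eos)

    loss_mask = [0] * len(prompt_token) + [1] * (len(seq_token) - len(prompt_token))
    assert loss_mask == [0, 0, 0, 1, 1]     # loss only on answer tokens
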
@@ -31,7 +31,7 @@ dependencies = [
     "huggingface_hub",
     "datasets",
     "accelerate",
-    "transformers>=4.53.3",
+    "transformers>=4.53.1",
 
     # Scientific computing
     "numpy<2.0.0",