This commit is contained in:
antoinegg1 2025-07-31 15:57:07 +08:00
parent d6a6240655
commit c5cd21d5db
7 changed files with 250 additions and 28 deletions

View File

@ -2,7 +2,7 @@ from typing import Optional
import transformers
VALID_DATASETS = ["gsm8k", "clevr_count_70k"]
VALID_DATASETS = ["gsm8k", "clevr_count_70k", "geometry3k"]
def get_custom_dataset(

View File

@ -0,0 +1,21 @@
from typing import Optional
import transformers
VALID_REWARD_FN = ["clevr_count_70k", "geometry3k"]


def custom_reward_fn(
    path: str,
    **kwargs,
):
    """Resolve the reward function that matches a dataset *path*.

    Matching is done by substring, mirroring how dataset paths name their
    task. Each reward module is imported lazily so only the selected one is
    loaded.

    Raises:
        ValueError: when *path* matches none of ``VALID_REWARD_FN``.
    """
    if "clevr_count_70k" in path:
        from examples.arealite.reward.clevr_count_70k import clevr_count_70k_reward_fn

        return clevr_count_70k_reward_fn

    if "geometry3k" in path:
        from examples.arealite.reward.geometry3k import geometry3k_reward_fn

        return geometry3k_reward_fn

    raise ValueError(
        f"Reward function {path} is not supported. "
        f"Supported reward functions are: {VALID_REWARD_FN}. "
    )

View File

@ -105,9 +105,12 @@ valid_dataset:
shuffle: true
pin_memory: true
num_workers: 4
path: /storage/openpsi/data/clevr_count_70k/
path: ${train_dataset.path}
type: rl
reward_fn:
path: ${train_dataset.path}
# Utilities
saver:
experiment_name: ${experiment_name}
@ -136,4 +139,7 @@ evaluator:
stats_logger:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
fileroot: ${cluster.fileroot}
fileroot: ${cluster.fileroot}
wandb:
project: clevr_count_70k-grpo

View File

@ -0,0 +1,145 @@
# GRPO training config for the geometry3k dataset.
# FIX: experiment_name previously said "clevr_count_70k-grpo" (copy-paste
# leftover from the clevr config); nfs_record_root, wandb.project, and
# train_dataset.path all target geometry3k, so the name is corrected to match.
experiment_name: geometry3k-grpo
trial_name: trial1
seed: 1
total_train_epochs: 3
tokenizer_path: ${actor.path}
async_training: true

cluster:
  n_nodes: 1
  n_gpus_per_node: 8
  cluster_name: na132
  fileroot: /storage/openpsi/experiments
  name_resolve:
    type: nfs
    nfs_record_root: /storage/openpsi/experiments/name_resolve/geometry3k-grpo
    etcd3_addr: etcd-client.openpsi-etcd.svc.sigma-na130-lingbo.na130.wl-robby.local:2379

# 1 GPU for the SGLang rollout server, 7 for FSDP training.
allocation_mode: sglang.d1p1t1+d7p1t1

rollout:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  max_concurrent_rollouts: 256
  queue_size: null
  consumer_batch_size: ${train_dataset.batch_size}
  max_head_offpolicyness: 4
  enable_rollout_tracing: false

# Generation config shared by rollout and the PPO actor.
gconfig:
  n_samples: 4
  min_new_tokens: 0
  max_new_tokens: 512
  greedy: false
  temperature: 1.0

actor:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  path: /storage/openpsi/models/Qwen2.5-VL-3B-Instruct
  init_from_scratch: false
  disable_dropout: true
  gradient_checkpointing: false
  dtype: bfloat16
  mb_spec:
    max_tokens_per_mb: 10240
  optimizer:
    type: adam
    lr: 2e-6
    weight_decay: 0.01
    beta1: 0.9
    beta2: 0.999
    eps: 1e-8
    lr_scheduler_type: constant
    gradient_clipping: 1.0
    warmup_steps_proportion: 0.001
  backend: fsdp
  group_size: ${gconfig.n_samples}
  group_adv_norm: false
  eps_clip: 0.4
  temperature: ${gconfig.temperature}
  reward_scaling: 10.0
  reward_bias: -0.5
  kl_ctl: 0.0
  ppo_n_minibatches: 1
  recompute_logprob: true
  use_decoupled_loss: true
  behav_imp_weight_cap: 5.0

ref:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  path: ${actor.path}
  init_from_scratch: false
  dtype: ${actor.dtype}
  mb_spec:
    max_tokens_per_mb: 10240
  optimizer: null
  backend: fsdp

# SGLang
server_only: false
sglang:
  model_path: ${actor.path}
  random_seed: ${seed}
  skip_tokenizer_init: true
  dtype: ${actor.dtype}
  max_running_requests: null
  context_length: 32768
  mem_fraction_static: 0.7

# datasets
train_dataset:
  batch_size: 32
  shuffle: true
  pin_memory: true
  num_workers: 4
  path: hiyouga/geometry3k
  type: rl

valid_dataset:
  batch_size: 32
  shuffle: true
  pin_memory: true
  num_workers: 4
  path: ${train_dataset.path}
  type: rl

# Reward fn is resolved from the dataset path (see custom_reward_fn).
reward_fn:
  path: ${train_dataset.path}

# Utilities
saver:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  freq_epochs: 1
  freq_steps: null
  freq_secs: null

checkpointer:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  freq_epochs: 1
  freq_steps: null
  freq_secs: 3600

evaluator:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  freq_epochs: 1
  freq_steps: null
  freq_secs: null

stats_logger:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  wandb:
    project: geometry3k-grpo

View File

@ -0,0 +1,26 @@
import re
def extract_answer(pred_str, data_name, use_last_number=True):
    """Extract a bracketed numeric answer such as ``[42]`` from *pred_str*.

    Args:
        pred_str: model completion text to scan.
        data_name: unused; kept for signature compatibility with callers.
        use_last_number: when True (default) return the last bracketed match,
            otherwise the first. FIX: this flag was previously ignored and the
            last match was always returned; the default preserves old behavior.

    Returns:
        The matched number as a string, or "" when nothing matches.
    """
    matches = re.findall(r"\[([0-9\.]+)\]", pred_str)
    if not matches:
        return ""
    return matches[-1] if use_last_number else matches[0]


def clevr_count_70k_reward_fn(
    prompt, completions, prompt_ids, completion_ids, answer, **kwargs
):
    """Binary reward: 1 if the extracted count matches *answer*, else 0.

    *completions* and *answer* are expected to be strings; comparison is
    whitespace-insensitive via ``strip()``.
    """
    sol = extract_answer(completions, data_name="")  # str number
    ans = answer
    # extract_answer never returns None, but guard both inputs defensively.
    if sol is None or ans is None:
        return 0
    if sol.strip() == ans.strip():
        # Debug trace on correct answers only — matches original behavior.
        print(f"completions: {completions}, answer: {answer}")
        return 1
    return 0

View File

@ -0,0 +1,42 @@
import re
def extract_answer(pred_str, data_name, use_last_number=True):
match = re.findall(r"\[([0-9\.]+)\]", pred_str)
if match:
return match[-1]
return ""
def geometry3k_reward_fn(
prompt, completions, prompt_ids, completion_ids, answer, **kwargs
):
sol = extract_answer(completions, data_name="") # str number
ans = answer
if sol is None:
return 0
if ans is None:
return 0
is_numeric = sol.replace('.', '', 1).isdigit() # Allows for decimal check
is_latex = sol.startswith("\\frac") or '\\sqrt' in sol
print(f"completions: {completions}, answer: {answer}")
# Exact answer matching
if sol == ans :
reward = 1
elif is_numeric and abs(float(sol) - float(ans)) < 1e-4:
reward = 0.8 # Reward for correct numerical approximation
elif is_latex:
# Check if numbers in LaTeX are correct
expected_numbers = re.findall(r'-?\d+\.?\d*', ans) # Find all numbers in expected answer
predicted_numbers = re.findall(r'-?\d+\.?\d*', sol) # Find all numbers in predicted answer
if len(expected_numbers) == len(predicted_numbers) and all(
abs(float(pred) - float(exp)) < 1e-4 for pred, exp in zip(predicted_numbers, expected_numbers)
):
reward = 0.6
else:
reward = 0
return reward

View File

@ -11,6 +11,7 @@ from torchdata.stateful_dataloader import StatefulDataLoader
from arealite.api.cli_args import GRPOConfig, load_expr_config
from arealite.api.io_struct import AllocationMode, FinetuneSpec, WeightUpdateMeta
from arealite.dataset.__init__ import get_custom_dataset
from arealite.reward.__init__ import custom_reward_fn
from arealite.engine.ppo.actor import FSDPPPOActor
from arealite.engine.sglang_remote import RemoteSGLangEngine
from arealite.utils.device import log_gpu_stats
@ -22,35 +23,12 @@ from realhf.api.core.data_api import load_hf_processor_and_tokenizer
from realhf.base import stats_tracker
def extract_answer(pred_str, data_name, use_last_number=True):
match = re.findall(r"\[([0-9\.]+)\]", pred_str)
if match:
return match[-1]
return ""
def clevr_count_70k_reward_fn(
prompt, completions, prompt_ids, completion_ids, answer, **kwargs
):
sol = extract_answer(completions, data_name="") # str number
ans = answer
if sol is None:
return 0
if ans is None:
return 0
if sol.strip() == ans.strip():
print(f"completions: {completions}, answer: {answer}")
return 1
return 0
def main(args):
wandb.init(project="clevr_70k")
config, _ = load_expr_config(args, GRPOConfig)
wandb.init(project=config.wandb.project)
config: GRPOConfig
@ -134,9 +112,13 @@ def main(args):
config.gconfig.stop_token_ids.append(tokenizer.pad_token_id)
if tokenizer.eos_token_id not in config.gconfig.stop_token_ids:
config.gconfig.stop_token_ids.append(tokenizer.eos_token_id)
reward_fn = custom_reward_fn(
path=config.reward_fn.path,
)
workflow = VisionRLVRWorkflow(
reward_fn=clevr_count_70k_reward_fn,
reward_fn=reward_fn,
gconfig=config.gconfig,
tokenizer=tokenizer,
processor=processor,