mirror of https://github.com/inclusionAI/AReaL
0731_2
This commit is contained in:
parent
d6a6240655
commit
c5cd21d5db
|
@ -2,7 +2,7 @@ from typing import Optional
|
|||
|
||||
import transformers
|
||||
|
||||
VALID_DATASETS = ["gsm8k", "clevr_count_70k"]
|
||||
VALID_DATASETS = ["gsm8k", "clevr_count_70k", "geometry3k"]
|
||||
|
||||
|
||||
def get_custom_dataset(
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
from typing import Optional
|
||||
|
||||
import transformers
|
||||
|
||||
VALID_REWARD_FN = ["clevr_count_70k", "geometry3k"]
|
||||
|
||||
def custom_reward_fn(
|
||||
path: str,
|
||||
**kwargs,
|
||||
):
|
||||
if "clevr_count_70k" in path:
|
||||
from examples.arealite.reward.clevr_count_70k import clevr_count_70k_reward_fn
|
||||
return clevr_count_70k_reward_fn
|
||||
elif "geometry3k" in path:
|
||||
from examples.arealite.reward.geometry3k import geometry3k_reward_fn
|
||||
return geometry3k_reward_fn
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Reward function {path} is not supported. "
|
||||
f"Supported reward functions are: {VALID_REWARD_FN}. "
|
||||
)
|
|
@ -105,9 +105,12 @@ valid_dataset:
|
|||
shuffle: true
|
||||
pin_memory: true
|
||||
num_workers: 4
|
||||
path: /storage/openpsi/data/clevr_count_70k/
|
||||
path: ${train_dataset.path}
|
||||
type: rl
|
||||
|
||||
reward_fn:
|
||||
path: ${train_dataset.path}
|
||||
|
||||
# Utilities
|
||||
saver:
|
||||
experiment_name: ${experiment_name}
|
||||
|
@ -136,4 +139,7 @@ evaluator:
|
|||
stats_logger:
|
||||
experiment_name: ${experiment_name}
|
||||
trial_name: ${trial_name}
|
||||
fileroot: ${cluster.fileroot}
|
||||
fileroot: ${cluster.fileroot}
|
||||
|
||||
wandb:
|
||||
project: clevr_count_70k-grpo
|
|
@ -0,0 +1,145 @@
|
|||
experiment_name: clevr_count_70k-grpo
|
||||
trial_name: trial1
|
||||
|
||||
|
||||
seed: 1
|
||||
total_train_epochs: 3
|
||||
tokenizer_path: ${actor.path}
|
||||
async_training: true
|
||||
|
||||
cluster:
|
||||
n_nodes: 1
|
||||
n_gpus_per_node: 8
|
||||
cluster_name: na132
|
||||
fileroot: /storage/openpsi/experiments
|
||||
name_resolve:
|
||||
type: nfs
|
||||
nfs_record_root: /storage/openpsi/experiments/name_resolve/geometry3k-grpo
|
||||
etcd3_addr: etcd-client.openpsi-etcd.svc.sigma-na130-lingbo.na130.wl-robby.local:2379
|
||||
|
||||
allocation_mode: sglang.d1p1t1+d7p1t1
|
||||
|
||||
rollout:
|
||||
experiment_name: ${experiment_name}
|
||||
trial_name: ${trial_name}
|
||||
max_concurrent_rollouts: 256
|
||||
queue_size: null
|
||||
consumer_batch_size: ${train_dataset.batch_size}
|
||||
max_head_offpolicyness: 4
|
||||
enable_rollout_tracing: false
|
||||
|
||||
gconfig:
|
||||
n_samples: 4
|
||||
min_new_tokens: 0
|
||||
max_new_tokens: 512
|
||||
greedy: false
|
||||
temperature: 1.0
|
||||
|
||||
actor:
|
||||
experiment_name: ${experiment_name}
|
||||
trial_name: ${trial_name}
|
||||
path: /storage/openpsi/models/Qwen2.5-VL-3B-Instruct
|
||||
init_from_scratch: false
|
||||
disable_dropout: true
|
||||
gradient_checkpointing: false
|
||||
dtype: bfloat16
|
||||
mb_spec:
|
||||
max_tokens_per_mb: 10240
|
||||
optimizer:
|
||||
type: adam
|
||||
lr: 2e-6
|
||||
weight_decay: 0.01
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
eps: 1e-8
|
||||
lr_scheduler_type: constant
|
||||
gradient_clipping: 1.0
|
||||
warmup_steps_proportion: 0.001
|
||||
backend: fsdp
|
||||
|
||||
group_size: ${gconfig.n_samples}
|
||||
group_adv_norm: false
|
||||
eps_clip: 0.4
|
||||
temperature: ${gconfig.temperature}
|
||||
reward_scaling: 10.0
|
||||
reward_bias: -0.5
|
||||
kl_ctl: 0.0
|
||||
ppo_n_minibatches: 1
|
||||
recompute_logprob: true
|
||||
use_decoupled_loss: true
|
||||
behav_imp_weight_cap: 5.0
|
||||
|
||||
ref:
|
||||
experiment_name: ${experiment_name}
|
||||
trial_name: ${trial_name}
|
||||
path: ${actor.path}
|
||||
init_from_scratch: false
|
||||
dtype: ${actor.dtype}
|
||||
mb_spec:
|
||||
max_tokens_per_mb: 10240
|
||||
optimizer: null
|
||||
backend: fsdp
|
||||
|
||||
# SGLang
|
||||
server_only: false
|
||||
sglang:
|
||||
model_path: ${actor.path}
|
||||
random_seed: ${seed}
|
||||
skip_tokenizer_init: true
|
||||
dtype: ${actor.dtype}
|
||||
max_running_requests: null
|
||||
context_length: 32768
|
||||
mem_fraction_static: 0.7
|
||||
|
||||
# datasets
|
||||
train_dataset:
|
||||
batch_size: 32
|
||||
shuffle: true
|
||||
pin_memory: true
|
||||
num_workers: 4
|
||||
path: hiyouga/geometry3k
|
||||
type: rl
|
||||
|
||||
valid_dataset:
|
||||
batch_size: 32
|
||||
shuffle: true
|
||||
pin_memory: true
|
||||
num_workers: 4
|
||||
path: ${train_dataset.path}
|
||||
type: rl
|
||||
|
||||
reward_fn:
|
||||
path: ${train_dataset.path}
|
||||
|
||||
# Utilities
|
||||
saver:
|
||||
experiment_name: ${experiment_name}
|
||||
trial_name: ${trial_name}
|
||||
fileroot: ${cluster.fileroot}
|
||||
freq_epochs: 1
|
||||
freq_steps: null
|
||||
freq_secs: null
|
||||
|
||||
checkpointer:
|
||||
experiment_name: ${experiment_name}
|
||||
trial_name: ${trial_name}
|
||||
fileroot: ${cluster.fileroot}
|
||||
freq_epochs: 1
|
||||
freq_steps: null
|
||||
freq_secs: 3600
|
||||
|
||||
evaluator:
|
||||
experiment_name: ${experiment_name}
|
||||
trial_name: ${trial_name}
|
||||
fileroot: ${cluster.fileroot}
|
||||
freq_epochs: 1
|
||||
freq_steps: null
|
||||
freq_secs: null
|
||||
|
||||
stats_logger:
|
||||
experiment_name: ${experiment_name}
|
||||
trial_name: ${trial_name}
|
||||
fileroot: ${cluster.fileroot}
|
||||
|
||||
wandb:
|
||||
project: geometry3k-grpo
|
|
@ -0,0 +1,26 @@
|
|||
import re
|
||||
|
||||
def extract_answer(pred_str, data_name, use_last_number=True):
|
||||
match = re.findall(r"\[([0-9\.]+)\]", pred_str)
|
||||
if match:
|
||||
return match[-1]
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def clevr_count_70k_reward_fn(
|
||||
prompt, completions, prompt_ids, completion_ids, answer, **kwargs
|
||||
):
|
||||
sol = extract_answer(completions, data_name="") # str number
|
||||
ans = answer
|
||||
|
||||
if sol is None:
|
||||
return 0
|
||||
if ans is None:
|
||||
return 0
|
||||
|
||||
if sol.strip() == ans.strip():
|
||||
print(f"completions: {completions}, answer: {answer}")
|
||||
return 1
|
||||
|
||||
return 0
|
|
@ -0,0 +1,42 @@
|
|||
import re
|
||||
|
||||
def extract_answer(pred_str, data_name, use_last_number=True):
|
||||
match = re.findall(r"\[([0-9\.]+)\]", pred_str)
|
||||
if match:
|
||||
return match[-1]
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def geometry3k_reward_fn(
|
||||
prompt, completions, prompt_ids, completion_ids, answer, **kwargs
|
||||
):
|
||||
sol = extract_answer(completions, data_name="") # str number
|
||||
ans = answer
|
||||
|
||||
if sol is None:
|
||||
return 0
|
||||
if ans is None:
|
||||
return 0
|
||||
|
||||
is_numeric = sol.replace('.', '', 1).isdigit() # Allows for decimal check
|
||||
is_latex = sol.startswith("\\frac") or '\\sqrt' in sol
|
||||
print(f"completions: {completions}, answer: {answer}")
|
||||
# Exact answer matching
|
||||
if sol == ans :
|
||||
reward = 1
|
||||
elif is_numeric and abs(float(sol) - float(ans)) < 1e-4:
|
||||
reward = 0.8 # Reward for correct numerical approximation
|
||||
elif is_latex:
|
||||
# Check if numbers in LaTeX are correct
|
||||
expected_numbers = re.findall(r'-?\d+\.?\d*', ans) # Find all numbers in expected answer
|
||||
predicted_numbers = re.findall(r'-?\d+\.?\d*', sol) # Find all numbers in predicted answer
|
||||
|
||||
if len(expected_numbers) == len(predicted_numbers) and all(
|
||||
abs(float(pred) - float(exp)) < 1e-4 for pred, exp in zip(predicted_numbers, expected_numbers)
|
||||
):
|
||||
reward = 0.6
|
||||
else:
|
||||
reward = 0
|
||||
|
||||
return reward
|
|
@ -11,6 +11,7 @@ from torchdata.stateful_dataloader import StatefulDataLoader
|
|||
from arealite.api.cli_args import GRPOConfig, load_expr_config
|
||||
from arealite.api.io_struct import AllocationMode, FinetuneSpec, WeightUpdateMeta
|
||||
from arealite.dataset.__init__ import get_custom_dataset
|
||||
from arealite.reward.__init__ import custom_reward_fn
|
||||
from arealite.engine.ppo.actor import FSDPPPOActor
|
||||
from arealite.engine.sglang_remote import RemoteSGLangEngine
|
||||
from arealite.utils.device import log_gpu_stats
|
||||
|
@ -22,35 +23,12 @@ from realhf.api.core.data_api import load_hf_processor_and_tokenizer
|
|||
from realhf.base import stats_tracker
|
||||
|
||||
|
||||
def extract_answer(pred_str, data_name, use_last_number=True):
|
||||
match = re.findall(r"\[([0-9\.]+)\]", pred_str)
|
||||
if match:
|
||||
return match[-1]
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def clevr_count_70k_reward_fn(
|
||||
prompt, completions, prompt_ids, completion_ids, answer, **kwargs
|
||||
):
|
||||
sol = extract_answer(completions, data_name="") # str number
|
||||
ans = answer
|
||||
|
||||
if sol is None:
|
||||
return 0
|
||||
if ans is None:
|
||||
return 0
|
||||
|
||||
if sol.strip() == ans.strip():
|
||||
print(f"completions: {completions}, answer: {answer}")
|
||||
return 1
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def main(args):
|
||||
|
||||
wandb.init(project="clevr_70k")
|
||||
wandb.init(project=config.wandb.project)
|
||||
|
||||
config, _ = load_expr_config(args, GRPOConfig)
|
||||
config: GRPOConfig
|
||||
|
@ -134,9 +112,13 @@ def main(args):
|
|||
config.gconfig.stop_token_ids.append(tokenizer.pad_token_id)
|
||||
if tokenizer.eos_token_id not in config.gconfig.stop_token_ids:
|
||||
config.gconfig.stop_token_ids.append(tokenizer.eos_token_id)
|
||||
|
||||
reward_fn = custom_reward_fn(
|
||||
path=config.reward_fn.path,
|
||||
)
|
||||
|
||||
workflow = VisionRLVRWorkflow(
|
||||
reward_fn=clevr_count_70k_reward_fn,
|
||||
reward_fn=reward_fn,
|
||||
gconfig=config.gconfig,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
Loading…
Reference in New Issue