bowei.fw 2025-07-14 14:58:28 +08:00
parent 037adedc70
commit 938c06a652
3 changed files with 39 additions and 5 deletions

examples/arealite/boba.py

@@ -33,7 +33,9 @@ from arealite.utils.evaluator import Evaluator
 from arealite.utils.saver import Saver
 from arealite.utils.stats_logger import StatsLogger
 from realhf.api.core.data_api import load_hf_tokenizer
-from realhf.base import stats_tracker
+from realhf.base import logging, stats_tracker
+
+logger = logging.getLogger("boba math")
 
 
 class RLVRWorkflow(RolloutWorkflow):
@@ -98,11 +100,33 @@ def get_boba_math_dataset(tokenizer, rank, world_size):
 def boba_reward_fn(
     prompt, completions, prompt_ids, completion_ids, query_id, solutions, **kwargs
 ):
+    from pebble import ProcessExpired, ProcessPool
+
     from realhf.impl.dataset.math_parser import process_results
 
+    jobs = []
+    with ProcessPool(max_workers=1) as executor:
+        for sol in solutions:
+            job = executor.schedule(
+                process_results, args=[completions, sol], timeout=15
+            )
+            jobs.append(job)
+
     label = 0
-    for sol in solutions:
-        label = label or process_results(completions, sol)[0]
+    for job in jobs:
+        try:
+            x = job.result()
+        except TimeoutError:
+            # print("[debug: timeout]")
+            logger.warning("Timeout occurred while verifying the math answer.")
+            x = (0, "timeout", "timeout")
+        except ProcessExpired as e:
+            logger.warning(f"Process terminated abnormally: {e}")
+            x = (0, "error", "error")
+        except Exception as e:
+            logger.warning(f"Other error occurred: {e.__class__.__name__}, {e}")
+            x = (0, "error", "error")
+        label = label or x[0]
     return label
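For context, a minimal standalone sketch of the pebble timeout pattern the new reward function relies on, assuming pebble is installed. slow_check is a hypothetical stand-in for process_results, and the 2-second timeout is illustrative. schedule() returns a future; result() raises concurrent.futures.TimeoutError once the deadline kills the worker, and ProcessExpired if the worker dies abnormally:

import time
from concurrent.futures import TimeoutError as PoolTimeoutError

from pebble import ProcessExpired, ProcessPool


def slow_check(answer, solution):
    # Stand-in for process_results: pretend some answers hang the parser.
    if answer == "hang":
        time.sleep(60)
    return (int(answer == solution), "ok", "ok")


if __name__ == "__main__":
    answers = ["41", "hang", "42"]
    with ProcessPool(max_workers=1) as pool:
        jobs = [
            pool.schedule(slow_check, args=[ans, "42"], timeout=2)
            for ans in answers
        ]
    # Exiting the with-block joins the pool; the hung worker is killed
    # at the 2 s deadline, so every future below resolves immediately.
    label = 0
    for job in jobs:
        try:
            x = job.result()
        except (TimeoutError, PoolTimeoutError):
            # pebble terminates the worker at the deadline; count as a miss.
            x = (0, "timeout", "timeout")
        except ProcessExpired:
            x = (0, "error", "error")
        label = label or x[0]
    print(label)  # 1: the last answer verified despite the hung one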
@@ -180,6 +204,8 @@ def main_grpo():
         batch = rollout.rollout(data, workflow=workflow)
         batch = batch.to(actor.device)
+        dist.barrier()
+        torch.cuda.synchronize()
 
         if config.actor.recompute_logprob:
             with stats_tracker.record_timing("recompute_logp"):
@@ -189,19 +215,27 @@ def main_grpo():
         else:
             batch["prox_logp"] = logp
         log_gpu_stats("Recompute logp")
+        dist.barrier()
+        torch.cuda.synchronize()
 
         if ref is not None:
             with stats_tracker.record_timing("ref_logp"):
                 batch["ref_logp"] = ref.compute_logp(batch)
                 log_gpu_stats("Ref logp")
+        dist.barrier()
+        torch.cuda.synchronize()
 
         with stats_tracker.record_timing("compute_advantage"):
             actor.compute_advantages(batch)
+        dist.barrier()
+        torch.cuda.synchronize()
 
         with (
             stats_tracker.record_timing("train_step"),
             stats_tracker.scope("grpo_actor"),
         ):
+            dist.barrier()
+            torch.cuda.synchronize()
             stats = actor.ppo_update(batch)
             actor.step_lr_scheduler()
             log_gpu_stats("PPO update")

examples/arealite/configs/boba.yaml

@@ -24,7 +24,7 @@ rollout:
   queue_size: null
   consumer_batch_size: ${train_dataset.batch_size}
   max_head_offpolicyness: 4
-  enable_rollout_tracing: false
+  enable_rollout_tracing: true
 gconfig:
   n_samples: 16
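Flipping enable_rollout_tracing to true presumably turns on per-request logging in the rollout worker, which fits the debugging-oriented changes in boba.py. If the launcher accepts the same dotted key=value overrides that run.sh already uses for trial_name, the flag could likely also be set for a single run as rollout.enable_rollout_tracing=true on the command line instead of editing the yaml.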

run.sh

@@ -2,4 +2,4 @@
 WANDB_API_KEY=local-5dd08fc1894114d0bea728566d5c35c5b31ee608 \
 WANDB_BASE_URL=http://8.150.1.98:8080 \
 python3 -m arealite.launcher.slurm examples/arealite/boba.py --config examples/arealite/configs/boba.yaml \
-trial_name=run0713-6
+trial_name=run0713-8