mirror of https://github.com/inclusionAI/AReaL

commit 938c06a652 ("fix")
parent 037adedc70
@@ -33,7 +33,9 @@ from arealite.utils.evaluator import Evaluator
 from arealite.utils.saver import Saver
 from arealite.utils.stats_logger import StatsLogger
 from realhf.api.core.data_api import load_hf_tokenizer
-from realhf.base import stats_tracker
+from realhf.base import logging, stats_tracker
 
+logger = logging.getLogger("boba math")
+
 
 class RLVRWorkflow(RolloutWorkflow):
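The logger added here comes from realhf.base.logging, whose setup this diff does not show. A rough standard-library stand-in for the same pattern, offered only as an assumption about its behavior:

import logging  # stdlib stand-in; realhf.base.logging may configure handlers itself

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("boba math")
logger.warning("example message")  # hypothetical usage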
@@ -98,11 +100,33 @@ def get_boba_math_dataset(tokenizer, rank, world_size):
 def boba_reward_fn(
     prompt, completions, prompt_ids, completion_ids, query_id, solutions, **kwargs
 ):
+    from pebble import ProcessExpired, ProcessPool
+
     from realhf.impl.dataset.math_parser import process_results
 
+    jobs = []
+    with ProcessPool(max_workers=1) as executor:
+        for sol in solutions:
+            job = executor.schedule(
+                process_results, args=[completions, sol], timeout=15
+            )
+            jobs.append(job)
+
     label = 0
-    for sol in solutions:
-        label = label or process_results(completions, sol)[0]
+    for job in jobs:
+        try:
+            x = job.result()
+        except TimeoutError:
+            # print("[debug: timeout]")
+            logger.warning("Timeout occurred while verifying the math answer.")
+            x = (0, "timeout", "timeout")
+        except ProcessExpired as e:
+            logger.warning(f"Process terminated abnormally: {e}")
+            x = (0, "error", "error")
+        except Exception as e:
+            logger.warning(f"Other error occurred: {e.__class__.__name__}, {e}")
+            x = (0, "error", "error")
+        label = label or x[0]
     return label
 
 
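The rewritten boba_reward_fn moves answer verification into a subprocess so a hung parser cannot stall the trainer. Below is a self-contained sketch of the same pebble pattern, with slow_verify as a hypothetical stand-in for process_results. One subtlety: pebble signals a task timeout with concurrent.futures.TimeoutError, which is only an alias of the builtin TimeoutError on Python 3.11+, so the sketch imports it explicitly.

import time
from concurrent.futures import TimeoutError as FutureTimeout

from pebble import ProcessExpired, ProcessPool


def slow_verify(answer, solution):
    # hypothetical stand-in for process_results
    time.sleep(0.1)  # simulate expensive answer parsing
    return (int(answer == solution), answer, solution)


if __name__ == "__main__":
    jobs = []
    with ProcessPool(max_workers=1) as pool:
        for sol in ["42", "41"]:
            # schedule() returns a future; the worker call is killed after 15s
            jobs.append(pool.schedule(slow_verify, args=["42", sol], timeout=15))

    label = 0
    for job in jobs:
        try:
            x = job.result()  # blocks until done; raises on timeout/crash
        except (FutureTimeout, ProcessExpired):
            x = (0, "error", "error")
        label = label or x[0]
    print(label)  # 1, because the first solution matched

With max_workers=1 the checks run serially; raising it would let the per-solution verifications proceed in parallel.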
@@ -180,6 +204,8 @@ def main_grpo():
         batch = rollout.rollout(data, workflow=workflow)
 
         batch = batch.to(actor.device)
+        dist.barrier()
+        torch.cuda.synchronize()
 
         if config.actor.recompute_logprob:
             with stats_tracker.record_timing("recompute_logp"):
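stats_tracker.record_timing appears here and in the next hunk as a context manager that attributes a block's wall-clock time to a named key. Its real implementation is not part of this diff; a hypothetical stand-in with the same shape:

import time
from contextlib import contextmanager

TIMINGS = {}  # hypothetical sink; the real stats_tracker aggregates elsewhere


@contextmanager
def record_timing(key):
    start = time.perf_counter()
    try:
        yield
    finally:
        # accumulate wall-clock seconds under the given key
        TIMINGS[key] = TIMINGS.get(key, 0.0) + time.perf_counter() - start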
@@ -189,19 +215,27 @@ def main_grpo():
                 else:
                     batch["prox_logp"] = logp
                 log_gpu_stats("Recompute logp")
+        dist.barrier()
+        torch.cuda.synchronize()
 
         if ref is not None:
             with stats_tracker.record_timing("ref_logp"):
                 batch["ref_logp"] = ref.compute_logp(batch)
                 log_gpu_stats("Ref logp")
+        dist.barrier()
+        torch.cuda.synchronize()
 
         with stats_tracker.record_timing("compute_advantage"):
             actor.compute_advantages(batch)
+        dist.barrier()
+        torch.cuda.synchronize()
 
         with (
             stats_tracker.record_timing("train_step"),
             stats_tracker.scope("grpo_actor"),
         ):
+            dist.barrier()
+            torch.cuda.synchronize()
             stats = actor.ppo_update(batch)
             actor.step_lr_scheduler()
             log_gpu_stats("PPO update")
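Each dist.barrier() plus torch.cuda.synchronize() pair added above forces every rank to finish the previous stage, and drain its queued CUDA kernels, before the next timed stage begins, so the record_timing numbers are not skewed by stragglers or still-running asynchronous kernels. A minimal sketch of the idiom, assuming the process group is already initialized (e.g. via init_process_group):

import torch
import torch.distributed as dist


def sync_point():
    """Align ranks and drain pending GPU work before starting the next timer."""
    dist.barrier()  # every rank waits here until all ranks arrive
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # CUDA launches are async; wait for completion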
@@ -24,7 +24,7 @@ rollout:
   queue_size: null
   consumer_batch_size: ${train_dataset.batch_size}
   max_head_offpolicyness: 4
-  enable_rollout_tracing: false
+  enable_rollout_tracing: true
 
 gconfig:
   n_samples: 16
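In the config above, consumer_batch_size is an interpolation that tracks train_dataset.batch_size. Assuming an OmegaConf-style loader (the diff does not show the project's config machinery), the reference resolves like this:

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "train_dataset": {"batch_size": 8},
        "rollout": {
            "consumer_batch_size": "${train_dataset.batch_size}",
            "enable_rollout_tracing": True,
        },
    }
)
assert cfg.rollout.consumer_batch_size == 8  # interpolation resolved on access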