PullRequest: 25 fix the save version in rw interface

Merge branch rw_save_version of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/25

Signed-off-by: 博惟 <bowei.fw@antgroup.com>


* fix the save version in rw interface
* format
温差 2025-03-11 16:30:03 +08:00
parent ffe1cd71ea
commit 480f0ef502
1 changed file with 89 additions and 96 deletions

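In short: the reward interface used to number its dump files with an in-memory `_call_count` that each process recovered at startup by scanning the log directory. This PR keys the files to `model.version.global_step` instead and advances the version via `model.inc_version()`, so dump names track the checkpointed training step across restarts. A minimal sketch of the resulting naming scheme (the ModelVersion class below is an illustrative stand-in, not the real AReaL type):

import os


class ModelVersion:
    # Stand-in for model.version: global_step is saved with checkpoints, so
    # dump files stay aligned with training steps across restarts.
    def __init__(self, global_step: int = 0):
        self.global_step = global_step


def dump_path(log_root: str, version: ModelVersion, rank: int) -> str:
    # After this PR, dump files are keyed by the checkpointed global step
    # (v{step}r{rank}) rather than a per-process call counter.
    return os.path.join(log_root, "generated", f"v{version.global_step}r{rank}.txt")


print(dump_path("/tmp/logs", ModelVersion(global_step=3), rank=0))
# /tmp/logs/generated/v3r0.txt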

@@ -179,8 +179,6 @@ class PackedMathRewardInterface(model_api.ModelInterface):
     group_size: int = 1
     check_verifier_status: bool = False
-    _call_count: int = 0
 
     def __post_init__(self):
         self.tokenizer = load_hf_tokenizer(self.tokenizer_path)
         logger.info(f"rm_output_scaling: {self.rm_output_scaling}")
@@ -191,20 +189,6 @@ class PackedMathRewardInterface(model_api.ModelInterface):
         logger.info(f"rw_type: {self.rw_type}")
         logger.info(f"post_process: {self.post_process}")
-        while True:
-            gen_file_path = os.path.join(
-                constants.LOG_ROOT,
-                constants.experiment_name(),
-                constants.trial_name(),
-                "generated",
-                f"v{self._call_count}r{dist.get_rank()}.txt",
-            )
-            if os.path.exists(gen_file_path):
-                self._call_count += 1
-            else:
-                break
-        logger.info(f"call_count: {self._call_count}")
 
     def inference(
         self,
         model: model_api.Model,
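The `while True` loop removed above recovered the counter by probing which dump files already existed, roughly as restated below (illustrative names). Since each rank inferred the value from the filesystem alone, it could drift from the checkpointed training step after a preemption or a partially written dump, which is the failure mode this PR removes:

import os


def recover_call_count(dump_dir: str, rank: int) -> int:
    # Count upward until no dump file exists for this version; the first gap
    # is taken as the next version to write. Fragile: file presence, not the
    # trainer's checkpointed step, decides the counter.
    call_count = 0
    while os.path.exists(os.path.join(dump_dir, f"v{call_count}r{rank}.txt")):
        call_count += 1
    return call_count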
@@ -270,92 +254,98 @@ class PackedMathRewardInterface(model_api.ModelInterface):
             f"before: Format success rate: {torch.FloatTensor(format_rewards).mean().item()}"
         )
         logger.info(f"number of samples: {len(scores)}, {scores.shape}")
-        gen_file_path = os.path.join(
-            constants.LOG_ROOT,
-            constants.experiment_name(),
-            constants.trial_name(),
-            "generated",
-            f"v{self._call_count}r{dist.get_rank()}.txt",
-        )
-        os.makedirs(os.path.dirname(gen_file_path), exist_ok=True)
-        logger.info(f"Generated samples and rewards will be dumped to: {gen_file_path}")
-        with open(gen_file_path, "w") as _f:
-            for idx, (score, prompt_str, seq_str) in enumerate(
-                zip(scores, prompt_strs, seq_strs)
-            ):
-                info = "\n".join(
-                    [
-                        f"idx: {idx} / {len(scores)}",
-                        f"reward is {score.item()}, prompt is {colorama.Fore.YELLOW + colorama.Style.DIM}{prompt_str}{colorama.Style.RESET_ALL}",
-                        f"sequence is: {colorama.Fore.YELLOW + colorama.Style.DIM}{seq_str.split(prompt_str)[1]}{colorama.Style.RESET_ALL}.",
-                    ]
-                )
-                _f.write(info + "\n")
-        gen_file_path = os.path.join(
-            constants.LOG_ROOT,
-            constants.experiment_name(),
-            constants.trial_name(),
-            "generated_jsonl",
-            f"v{self._call_count}r{dist.get_rank()}.jsonl",
-        )
-        os.makedirs(os.path.dirname(gen_file_path), exist_ok=True)
-        logger.info(f"Generated samples and rewards will be dumped to: {gen_file_path}")
-        with open(gen_file_path, "w") as _f:
-            for idx, (score, prompt_str, seq_str) in enumerate(
-                zip(scores, prompt_strs, seq_strs)
-            ):
-                _f.write(
-                    json.dumps(
-                        {
-                            "prompt": prompt_str,
-                            "generated": seq_str.split(prompt_str)[1],
-                            "reward": score.item(),
-                        },
-                        ensure_ascii=False,
-                    )
-                    + "\n"
-                )
-        logger.info(
-            f"Format success rate: {torch.FloatTensor(format_rewards).mean().item()}"
-        )
-        pass_at_k = np.mean(
-            [sum([xx == 1 for xx in x]) > 0 for x in queryid_to_results.values()]
-        )
-        avg_num_samples = np.mean([len(x) for x in queryid_to_results.values()])
-        logger.info(f"pass@k: {pass_at_k}, num_samples: {avg_num_samples}")
-        logger.info(f"number of samples: {len(scores)}, {scores.shape}")
-        logger.info(f"reward: {sum(scores) / len(scores)}")
-        train_pass_monitor_file_path = os.path.join(
-            constants.LOG_ROOT,
-            constants.experiment_name(),
-            constants.trial_name(),
-            "training_monitor",
-            f"v{self._call_count}r{dist.get_rank()}.jsonl",
-        )
-        os.makedirs(os.path.dirname(train_pass_monitor_file_path), exist_ok=True)
-        logger.info(
-            f"pass monitor result will be dumped to: {train_pass_monitor_file_path}"
-        )
-        with open(train_pass_monitor_file_path, "w") as monitor_file:
-            for key, value in queryid_to_results.items():
-                pass1 = sum(value) / len(value)
-                pass8 = int(sum(value) > 0)
-                monitor_file.write(
-                    json.dumps(
-                        {"query_id": key, "pass1": pass1, "pass8": pass8},
-                        ensure_ascii=False,
-                    )
-                    + "\n"
-                )
-        self._call_count += 1
+        if constants.is_last_pipe_stage():
+            gen_file_path = os.path.join(
+                constants.LOG_ROOT,
+                constants.experiment_name(),
+                constants.trial_name(),
+                "generated",
+                f"v{model.version.global_step}r{dist.get_rank()}.txt",
+            )
+            os.makedirs(os.path.dirname(gen_file_path), exist_ok=True)
+            logger.info(
+                f"Generated samples and rewards will be dumped to: {gen_file_path}"
+            )
+            with open(gen_file_path, "w") as _f:
+                for idx, (score, prompt_str, seq_str) in enumerate(
+                    zip(scores, prompt_strs, seq_strs)
+                ):
+                    info = "\n".join(
+                        [
+                            f"idx: {idx} / {len(scores)}",
+                            f"reward is {score.item()}, prompt is {colorama.Fore.YELLOW + colorama.Style.DIM}{prompt_str}{colorama.Style.RESET_ALL}",
+                            f"sequence is: {colorama.Fore.YELLOW + colorama.Style.DIM}{seq_str.split(prompt_str)[1]}{colorama.Style.RESET_ALL}.",
+                        ]
+                    )
+                    _f.write(info + "\n")
+            gen_file_path = os.path.join(
+                constants.LOG_ROOT,
+                constants.experiment_name(),
+                constants.trial_name(),
+                "generated_jsonl",
+                f"v{model.version.global_step}r{dist.get_rank()}.jsonl",
+            )
+            os.makedirs(os.path.dirname(gen_file_path), exist_ok=True)
+            logger.info(
+                f"Generated samples and rewards will be dumped to: {gen_file_path}"
+            )
+            with open(gen_file_path, "w") as _f:
+                for idx, (score, prompt_str, seq_str) in enumerate(
+                    zip(scores, prompt_strs, seq_strs)
+                ):
+                    _f.write(
+                        json.dumps(
+                            {
+                                "prompt": prompt_str,
+                                "generated": seq_str.split(prompt_str)[1],
+                                "reward": score.item(),
+                            },
+                            ensure_ascii=False,
+                        )
+                        + "\n"
+                    )
+            logger.info(
+                f"Format success rate: {torch.FloatTensor(format_rewards).mean().item()}"
+            )
+            pass_at_k = np.mean(
+                [sum([xx == 1 for xx in x]) > 0 for x in queryid_to_results.values()]
+            )
+            avg_num_samples = np.mean([len(x) for x in queryid_to_results.values()])
+            logger.info(f"pass@k: {pass_at_k}, num_samples: {avg_num_samples}")
+            logger.info(f"number of samples: {len(scores)}, {scores.shape}")
+            logger.info(f"reward: {sum(scores) / len(scores)}")
+            train_pass_monitor_file_path = os.path.join(
+                constants.LOG_ROOT,
+                constants.experiment_name(),
+                constants.trial_name(),
+                "training_monitor",
+                f"v{model.version.global_step}r{dist.get_rank()}.jsonl",
+            )
+            os.makedirs(os.path.dirname(train_pass_monitor_file_path), exist_ok=True)
+            logger.info(
+                f"pass monitor result will be dumped to: {train_pass_monitor_file_path}"
+            )
+            with open(train_pass_monitor_file_path, "w") as monitor_file:
+                for key, value in queryid_to_results.items():
+                    pass1 = sum(value) / len(value)
+                    pass8 = int(sum(value) > 0)
+                    monitor_file.write(
+                        json.dumps(
+                            {"query_id": key, "pass1": pass1, "pass8": pass8},
+                            ensure_ascii=False,
+                        )
+                        + "\n"
+                    )
+        model.inc_version()
         if scores.dtype != torch.float32:
             scores = scores.to(torch.float32)
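The rewritten block gates all three dumps (txt, jsonl, training monitor) on constants.is_last_pipe_stage() and bumps the model version once per inference call. A minimal sketch of the pattern (FakeModel stands in for the real model_api.Model; that inc_version() runs on every rank rather than only on the last stage is an assumption of this sketch):

class FakeModel:
    # Stand-in for model_api.Model: only the pieces the pattern needs.
    def __init__(self):
        self.global_step = 0

    def inc_version(self):
        self.global_step += 1


def dump_and_advance(model: FakeModel, is_last_pipe_stage: bool, payload: str):
    if is_last_pipe_stage:
        # Only the last pipeline stage holds the final scores, so only it
        # writes dump files.
        print(f"v{model.global_step}: {payload}")
    # Advancing the version once per call keeps dump names in lockstep with
    # training steps.
    model.inc_version()


m = FakeModel()
dump_and_advance(m, is_last_pipe_stage=True, payload="rewards")  # v0: rewards
dump_and_advance(m, is_last_pipe_stage=True, payload="rewards")  # v1: rewards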
@@ -405,6 +395,9 @@ class PackedMathRewardInterface(model_api.ModelInterface):
             raise MathVerifierException(
                 "All rewards are at minimal value. Probably there are something wrong with the verifier!"
             )
+        if not constants.is_last_pipe_stage():
+            return None
+
         return res
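With the guard above, non-last pipeline stages now return None instead of the reward result. A small sketch of the caller-side contract this implies (an assumption about downstream code, not shown in this diff):

from typing import Optional


def consume(res: Optional[dict]) -> dict:
    # None from a non-last pipeline stage means "nothing to aggregate here".
    return {} if res is None else res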