mirror of https://github.com/inclusionAI/AReaL
PullRequest: 25 fix the save version in rw interface
Merge branch rw_save_version of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/25 Signed-off-by: 博惟 <bowei.fw@antgroup.com> * fix the save version in rw interface * format
This commit is contained in:
parent
ffe1cd71ea
commit
480f0ef502
|
@ -179,8 +179,6 @@ class PackedMathRewardInterface(model_api.ModelInterface):
|
|||
group_size: int = 1
|
||||
check_verifier_status: bool = False
|
||||
|
||||
_call_count: int = 0
|
||||
|
||||
def __post_init__(self):
|
||||
self.tokenizer = load_hf_tokenizer(self.tokenizer_path)
|
||||
logger.info(f"rm_output_scaling: {self.rm_output_scaling}")
|
||||
|
@ -191,20 +189,6 @@ class PackedMathRewardInterface(model_api.ModelInterface):
|
|||
logger.info(f"rw_type: {self.rw_type}")
|
||||
logger.info(f"post_process: {self.post_process}")
|
||||
|
||||
while True:
|
||||
gen_file_path = os.path.join(
|
||||
constants.LOG_ROOT,
|
||||
constants.experiment_name(),
|
||||
constants.trial_name(),
|
||||
"generated",
|
||||
f"v{self._call_count}r{dist.get_rank()}.txt",
|
||||
)
|
||||
if os.path.exists(gen_file_path):
|
||||
self._call_count += 1
|
||||
else:
|
||||
break
|
||||
logger.info(f"call_count: {self._call_count}")
|
||||
|
||||
def inference(
|
||||
self,
|
||||
model: model_api.Model,
|
||||
|
@ -270,92 +254,98 @@ class PackedMathRewardInterface(model_api.ModelInterface):
|
|||
f"before: Format success rate: {torch.FloatTensor(format_rewards).mean().item()}"
|
||||
)
|
||||
logger.info(f"number of samples: {len(scores)}, {scores.shape}")
|
||||
if constants.is_last_pipe_stage():
|
||||
|
||||
gen_file_path = os.path.join(
|
||||
constants.LOG_ROOT,
|
||||
constants.experiment_name(),
|
||||
constants.trial_name(),
|
||||
"generated",
|
||||
f"v{self._call_count}r{dist.get_rank()}.txt",
|
||||
)
|
||||
os.makedirs(os.path.dirname(gen_file_path), exist_ok=True)
|
||||
logger.info(f"Generated samples and rewards will be dumped to: {gen_file_path}")
|
||||
with open(gen_file_path, "w") as _f:
|
||||
for idx, (score, prompt_str, seq_str) in enumerate(
|
||||
zip(scores, prompt_strs, seq_strs)
|
||||
):
|
||||
info = "\n".join(
|
||||
[
|
||||
f"idx: {idx} / {len(scores)}",
|
||||
f"reward is {score.item()}, prompt is {colorama.Fore.YELLOW + colorama.Style.DIM}{prompt_str}{colorama.Style.RESET_ALL}",
|
||||
f"sequence is: {colorama.Fore.YELLOW + colorama.Style.DIM}{seq_str.split(prompt_str)[1]}{colorama.Style.RESET_ALL}.",
|
||||
]
|
||||
)
|
||||
_f.write(info + "\n")
|
||||
gen_file_path = os.path.join(
|
||||
constants.LOG_ROOT,
|
||||
constants.experiment_name(),
|
||||
constants.trial_name(),
|
||||
"generated",
|
||||
f"v{model.version.global_step}r{dist.get_rank()}.txt",
|
||||
)
|
||||
|
||||
gen_file_path = os.path.join(
|
||||
constants.LOG_ROOT,
|
||||
constants.experiment_name(),
|
||||
constants.trial_name(),
|
||||
"generated_jsonl",
|
||||
f"v{self._call_count}r{dist.get_rank()}.jsonl",
|
||||
)
|
||||
os.makedirs(os.path.dirname(gen_file_path), exist_ok=True)
|
||||
logger.info(f"Generated samples and rewards will be dumped to: {gen_file_path}")
|
||||
with open(gen_file_path, "w") as _f:
|
||||
for idx, (score, prompt_str, seq_str) in enumerate(
|
||||
zip(scores, prompt_strs, seq_strs)
|
||||
):
|
||||
_f.write(
|
||||
json.dumps(
|
||||
{
|
||||
"prompt": prompt_str,
|
||||
"generated": seq_str.split(prompt_str)[1],
|
||||
"reward": score.item(),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
os.makedirs(os.path.dirname(gen_file_path), exist_ok=True)
|
||||
logger.info(
|
||||
f"Generated samples and rewards will be dumped to: {gen_file_path}"
|
||||
)
|
||||
with open(gen_file_path, "w") as _f:
|
||||
for idx, (score, prompt_str, seq_str) in enumerate(
|
||||
zip(scores, prompt_strs, seq_strs)
|
||||
):
|
||||
info = "\n".join(
|
||||
[
|
||||
f"idx: {idx} / {len(scores)}",
|
||||
f"reward is {score.item()}, prompt is {colorama.Fore.YELLOW + colorama.Style.DIM}{prompt_str}{colorama.Style.RESET_ALL}",
|
||||
f"sequence is: {colorama.Fore.YELLOW + colorama.Style.DIM}{seq_str.split(prompt_str)[1]}{colorama.Style.RESET_ALL}.",
|
||||
]
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
_f.write(info + "\n")
|
||||
|
||||
logger.info(
|
||||
f"Format success rate: {torch.FloatTensor(format_rewards).mean().item()}"
|
||||
)
|
||||
|
||||
pass_at_k = np.mean(
|
||||
[sum([xx == 1 for xx in x]) > 0 for x in queryid_to_results.values()]
|
||||
)
|
||||
avg_num_samples = np.mean([len(x) for x in queryid_to_results.values()])
|
||||
logger.info(f"pass@k: {pass_at_k}, num_samples: {avg_num_samples}")
|
||||
|
||||
logger.info(f"number of samples: {len(scores)}, {scores.shape}")
|
||||
|
||||
logger.info(f"reward: {sum(scores) / len(scores)}")
|
||||
|
||||
train_pass_monitor_file_path = os.path.join(
|
||||
constants.LOG_ROOT,
|
||||
constants.experiment_name(),
|
||||
constants.trial_name(),
|
||||
"training_monitor",
|
||||
f"v{self._call_count}r{dist.get_rank()}.jsonl",
|
||||
)
|
||||
os.makedirs(os.path.dirname(train_pass_monitor_file_path), exist_ok=True)
|
||||
logger.info(
|
||||
f"pass monitor result will be dumped to: {train_pass_monitor_file_path}"
|
||||
)
|
||||
with open(train_pass_monitor_file_path, "w") as monitor_file:
|
||||
for key, value in queryid_to_results.items():
|
||||
pass1 = sum(value) / len(value)
|
||||
pass8 = int(sum(value) > 0)
|
||||
monitor_file.write(
|
||||
json.dumps(
|
||||
{"query_id": key, "pass1": pass1, "pass8": pass8},
|
||||
ensure_ascii=False,
|
||||
gen_file_path = os.path.join(
|
||||
constants.LOG_ROOT,
|
||||
constants.experiment_name(),
|
||||
constants.trial_name(),
|
||||
"generated_jsonl",
|
||||
f"v{model.version.global_step}r{dist.get_rank()}.jsonl",
|
||||
)
|
||||
os.makedirs(os.path.dirname(gen_file_path), exist_ok=True)
|
||||
logger.info(
|
||||
f"Generated samples and rewards will be dumped to: {gen_file_path}"
|
||||
)
|
||||
with open(gen_file_path, "w") as _f:
|
||||
for idx, (score, prompt_str, seq_str) in enumerate(
|
||||
zip(scores, prompt_strs, seq_strs)
|
||||
):
|
||||
_f.write(
|
||||
json.dumps(
|
||||
{
|
||||
"prompt": prompt_str,
|
||||
"generated": seq_str.split(prompt_str)[1],
|
||||
"reward": score.item(),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
self._call_count += 1
|
||||
logger.info(
|
||||
f"Format success rate: {torch.FloatTensor(format_rewards).mean().item()}"
|
||||
)
|
||||
|
||||
pass_at_k = np.mean(
|
||||
[sum([xx == 1 for xx in x]) > 0 for x in queryid_to_results.values()]
|
||||
)
|
||||
avg_num_samples = np.mean([len(x) for x in queryid_to_results.values()])
|
||||
logger.info(f"pass@k: {pass_at_k}, num_samples: {avg_num_samples}")
|
||||
|
||||
logger.info(f"number of samples: {len(scores)}, {scores.shape}")
|
||||
|
||||
logger.info(f"reward: {sum(scores) / len(scores)}")
|
||||
|
||||
train_pass_monitor_file_path = os.path.join(
|
||||
constants.LOG_ROOT,
|
||||
constants.experiment_name(),
|
||||
constants.trial_name(),
|
||||
"training_monitor",
|
||||
f"v{model.version.global_step}r{dist.get_rank()}.jsonl",
|
||||
)
|
||||
os.makedirs(os.path.dirname(train_pass_monitor_file_path), exist_ok=True)
|
||||
logger.info(
|
||||
f"pass monitor result will be dumped to: {train_pass_monitor_file_path}"
|
||||
)
|
||||
with open(train_pass_monitor_file_path, "w") as monitor_file:
|
||||
for key, value in queryid_to_results.items():
|
||||
pass1 = sum(value) / len(value)
|
||||
pass8 = int(sum(value) > 0)
|
||||
monitor_file.write(
|
||||
json.dumps(
|
||||
{"query_id": key, "pass1": pass1, "pass8": pass8},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
model.inc_version()
|
||||
|
||||
if scores.dtype != torch.float32:
|
||||
scores = scores.to(torch.float32)
|
||||
|
@ -405,6 +395,9 @@ class PackedMathRewardInterface(model_api.ModelInterface):
|
|||
raise MathVerifierException(
|
||||
"All rewards are at minimal value. Probably there are something wrong with the verifier!"
|
||||
)
|
||||
|
||||
if not constants.is_last_pipe_stage():
|
||||
return None
|
||||
return res
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue