AReaL/arealite/api/reward_api.py


from typing import List


def reward_fn(
    prompt: str,
    completions: str,
    prompt_ids: List[int],
    completion_ids: List[int],
    **kwargs,
):
    """Placeholder for the reward function used in the RLVR pipeline.

    Customized rollout workflows place no restriction on the signature or
    implementation of this function. Following this signature is convenient,
    however, because the function can then be used directly in the predefined
    rollout workflows.

    :param prompt: The string representing the task to be completed.
    :param completions: The string representing the trajectory generated by the model.
    :param prompt_ids: The token IDs of the prompt.
    :param completion_ids: The token IDs of the trajectory generated by the model.
    :param kwargs: Other attributes of the data in the dataset, such as solutions, input_outputs, etc.
        Any other attributes in the dataset are passed as keyword arguments to this function.
    :rtype: float
    """
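
The placeholder above only fixes a calling convention, so a concrete example may help show how a reward function following it could look. The sketch below is illustrative and not part of the library: `exact_match_reward_fn`, `extract_final_number`, and the `answer` dataset field are hypothetical names chosen for this example, and the numeric exact-match rule is just one simple choice of verifiable reward.

```python
import re
from typing import List, Optional


def extract_final_number(text: str) -> Optional[str]:
    """Return the last number appearing in `text`, or None if there is none."""
    matches = re.findall(r"-?\d+(?:\.\d+)?", text)
    return matches[-1] if matches else None


def exact_match_reward_fn(
    prompt: str,
    completions: str,
    prompt_ids: List[int],
    completion_ids: List[int],
    **kwargs,
) -> float:
    """Hypothetical reward: 1.0 if the last number in the completion equals the reference."""
    # Assumption: the dataset carries a field named `answer`; it arrives here
    # through **kwargs like any other extra dataset attribute.
    reference = kwargs.get("answer")
    if reference is None:
        return 0.0

    predicted = extract_final_number(completions)
    if predicted is None:
        return 0.0

    try:
        # Compare numerically so that "4" and "4.0" count as the same answer.
        return 1.0 if float(predicted) == float(reference) else 0.0
    except (TypeError, ValueError):
        # The reference could not be interpreted as a number.
        return 0.0
```

Under these assumptions, a workflow would call it with the decoded text, the token IDs, and the remaining dataset attributes as keyword arguments, for example `exact_match_reward_fn(prompt, completion, prompt_ids, completion_ids, answer="4")`. The token IDs are unused here; they are accepted only to match the signature.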