mirror of https://github.com/inclusionAI/AReaL
31 lines
798 B
Python
31 lines
798 B
Python
from typing import Callable
|
|
|
|
from arealite.api.cli_args import EvaluatorConfig
|
|
from arealite.api.io_struct import FinetuneSpec
|
|
from realhf.base import timeutil
|
|
|
|
|
|
class Evaluator:
    """Run an evaluation callback at a configured epoch/step/time frequency.

    The decision of whether an evaluation is due is delegated to a
    ``timeutil.EpochStepTimeFreqCtl`` built from the ``freq_epochs``,
    ``freq_steps`` and ``freq_secs`` fields of the config.
    """

    def __init__(self, config: "EvaluatorConfig", ft_spec: "FinetuneSpec"):
        """Store the configuration and build the frequency controller.

        Args:
            config: Evaluator settings. Its ``freq_epochs``, ``freq_steps``
                and ``freq_secs`` fields are forwarded to the frequency
                controller.
            ft_spec: Finetune specification. Its ``steps_per_epoch`` field is
                read by :meth:`evaluate` to detect the last step of an epoch.
        """
        self.config = config
        self.ft_spec = ft_spec
        self.freq_ctl = timeutil.EpochStepTimeFreqCtl(
            freq_epoch=config.freq_epochs,
            freq_step=config.freq_steps,
            freq_sec=config.freq_secs,
        )

    @property
    def ft_sepc(self):
        """Backward-compatible alias for the misspelled name of ``ft_spec``."""
        return self.ft_spec

    @ft_sepc.setter
    def ft_sepc(self, value):
        # Keep writes through the legacy name in sync with the real attribute.
        self.ft_spec = value

    def evaluate(
        self,
        evaluate_fn: Callable[[], object],
        epoch: int,
        step: int,
        global_step: int,
    ) -> None:
        """Call ``evaluate_fn`` if the frequency controller says it is due.

        Every call reports one step to the frequency controller, plus one
        epoch when ``step`` is the last step index of the epoch
        (``steps_per_epoch - 1``).

        Args:
            evaluate_fn: Zero-argument callable that performs the evaluation.
                Its return value is discarded.
            epoch: Current epoch index. Accepted for interface symmetry; not
                used by this method.
            step: Step index within the current epoch.
            global_step: Global step index. Accepted for interface symmetry;
                not used by this method.
        """
        epoch_finished = step == self.ft_spec.steps_per_epoch - 1
        if not self.freq_ctl.check(epochs=int(epoch_finished), steps=1):
            return
        evaluate_fn()
|