This commit is contained in:
bowei.fw 2025-03-22 12:33:17 +08:00
parent 25c45c7e83
commit 88e99f887a
4 changed files with 5 additions and 6 deletions

View File

@ -8,7 +8,6 @@ import os
import random
import time
from contextlib import contextmanager
from enum import Enum
# NOTE: We don't use wildcard importing here because the type
# `Sequence` has a very similar name to `SequenceSample`.

View File

@ -21,7 +21,7 @@ def check_is_realhf_native_impl(_cls):
def check_is_realhf_native_model_interface(name):
# NOTE: we should not import interfaces here,
# such that we can avoid CUDA initialization.
return name in ["ppo_actor", "ppo_critic", "sft", "ref_rw"]
return name in ["ppo_actor", "ppo_critic", "sft", "reward", "fused-threading"]
def check_valid_vllm(role: str, vllm: vLLMConfig, rpc_allocs: List[RPCAllocation]):

View File

@ -19,7 +19,7 @@ from realhf.api.core.system_api import ExperimentConfig
from realhf.api.quickstart.dataset import PromptOnlyDatasetConfig
from realhf.api.quickstart.device_mesh import MFCConfig
from realhf.api.quickstart.entrypoint import register_quickstart_exp
from realhf.api.quickstart.model import ModelTrainEvalConfig, ParallelismConfig
from realhf.api.quickstart.model import ModelTrainEvalConfig
from realhf.experiments.common.common import CommonExperimentConfig
from realhf.experiments.common.utils import (
asdict,
@ -196,8 +196,8 @@ class PPOMATHConfig(CommonExperimentConfig):
critic_train: MFCConfig = dataclasses.field(default_factory=MFCConfig)
actor_gen: MFCConfig = dataclasses.field(default_factory=MFCConfig)
critic_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
ref_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
rew_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
ref_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
actor_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
dataset: PromptOnlyDatasetConfig = dataclasses.field(
@ -300,7 +300,7 @@ class PPOMATHConfig(CommonExperimentConfig):
)
critic_interface.args.pop("eps_clip")
rw_interface = ModelInterfaceAbstraction(
"reward",
"rw-math-code",
args=dict(
dataset_path=self.dataset.path,
tokenizer_path=self.actor.path,

View File

@ -461,4 +461,4 @@ class MultiTaskRewardInterface(model_api.ModelInterface):
return data
model_api.register_interface("reward", MultiTaskRewardInterface)
model_api.register_interface("rw-math-code", MultiTaskRewardInterface)