mirror of https://github.com/inclusionAI/AReaL
This commit is contained in:
parent
25c45c7e83
commit
88e99f887a
|
@ -8,7 +8,6 @@ import os
|
|||
import random
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum
|
||||
|
||||
# NOTE: We don't use wildcard importing here because the type
|
||||
# `Sequence` has a very similar name to `SequenceSample`.
|
||||
|
|
|
@ -21,7 +21,7 @@ def check_is_realhf_native_impl(_cls):
|
|||
def check_is_realhf_native_model_interface(name):
|
||||
# NOTE: we should not import interfaces here,
|
||||
# such that we can avoid CUDA initialization.
|
||||
return name in ["ppo_actor", "ppo_critic", "sft", "ref_rw"]
|
||||
return name in ["ppo_actor", "ppo_critic", "sft", "reward", "fused-threading"]
|
||||
|
||||
|
||||
def check_valid_vllm(role: str, vllm: vLLMConfig, rpc_allocs: List[RPCAllocation]):
|
||||
|
|
|
@ -19,7 +19,7 @@ from realhf.api.core.system_api import ExperimentConfig
|
|||
from realhf.api.quickstart.dataset import PromptOnlyDatasetConfig
|
||||
from realhf.api.quickstart.device_mesh import MFCConfig
|
||||
from realhf.api.quickstart.entrypoint import register_quickstart_exp
|
||||
from realhf.api.quickstart.model import ModelTrainEvalConfig, ParallelismConfig
|
||||
from realhf.api.quickstart.model import ModelTrainEvalConfig
|
||||
from realhf.experiments.common.common import CommonExperimentConfig
|
||||
from realhf.experiments.common.utils import (
|
||||
asdict,
|
||||
|
@ -196,8 +196,8 @@ class PPOMATHConfig(CommonExperimentConfig):
|
|||
critic_train: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
||||
actor_gen: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
||||
critic_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
||||
ref_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
||||
rew_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
||||
ref_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
||||
actor_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
||||
|
||||
dataset: PromptOnlyDatasetConfig = dataclasses.field(
|
||||
|
@ -300,7 +300,7 @@ class PPOMATHConfig(CommonExperimentConfig):
|
|||
)
|
||||
critic_interface.args.pop("eps_clip")
|
||||
rw_interface = ModelInterfaceAbstraction(
|
||||
"reward",
|
||||
"rw-math-code",
|
||||
args=dict(
|
||||
dataset_path=self.dataset.path,
|
||||
tokenizer_path=self.actor.path,
|
||||
|
|
|
@ -461,4 +461,4 @@ class MultiTaskRewardInterface(model_api.ModelInterface):
|
|||
return data
|
||||
|
||||
|
||||
model_api.register_interface("reward", MultiTaskRewardInterface)
|
||||
model_api.register_interface("rw-math-code", MultiTaskRewardInterface)
|
Loading…
Reference in New Issue