mirror of https://github.com/inclusionAI/AReaL
PullRequest: 35 Support log probability recomputation in PPO.
Merge branch fw/recompute-logprob of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/35 Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com> * . * . * . * .
This commit is contained in:
parent
c074d6142b
commit
69681d9fe4
|
@ -2,9 +2,11 @@
|
|||
# Copyright 2024 Wei Fu & Zhiyu Mei
|
||||
# Licensed under the Apache License, Version 2.0 (the "License").
|
||||
import os
|
||||
from importlib.metadata import version
|
||||
from typing import List
|
||||
|
||||
import realhf.api.core.model_api as model_api
|
||||
from packaging.version import Version
|
||||
|
||||
from realhf.api.quickstart.device_mesh import RPCAllocation
|
||||
from realhf.api.quickstart.model import ModelTrainEvalConfig, vLLMConfig
|
||||
from realhf.base import logging
|
||||
|
@ -27,7 +29,10 @@ def check_valid_vllm(role: str, vllm: vLLMConfig, rpc_allocs: List[RPCAllocation
|
|||
if vllm.hybrid_train and not any(rpc.is_train() for rpc in rpcs):
|
||||
logger.warning("vLLM hybrid_train is enabled, but no training RPCs are found.")
|
||||
if vllm.hybrid_train and not vllm.enforce_eager:
|
||||
raise ValueError("vLLM hybrid_train requires eager mode to be enabled.")
|
||||
logger.warning(
|
||||
"For version < 0.7.0, vLLM hybrid_train requires eager mode to be enabled. "
|
||||
"The user has the responsibility to ensure the version is correct."
|
||||
)
|
||||
|
||||
|
||||
def check_valid_optimizer(model: ModelTrainEvalConfig):
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
|
||||
import dataclasses
|
||||
import itertools
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from typing import *
|
||||
|
||||
|
@ -69,7 +70,7 @@ import realhf.api.from_hf # isort:skip
|
|||
|
||||
logger = logging.getLogger("CommonExperimentConfig", "colored")
|
||||
|
||||
vLLM_HYBRID_TRAIN_DECOUPLE_ALLOC_WARN = False
|
||||
GEN_HYBRID_TRAIN_DECOUPLE_ALLOC_WARN = False
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
@ -108,7 +109,7 @@ class CommonExperimentConfig(Experiment):
|
|||
|
||||
- A regex pattern like ``d${DP}p${PP}m${TP}``\: Identical parallelization for all MFCs with ${DP}-way data parallelism, ${PP}-way pipeline parallelism, and ${TP}-way model parallelism.
|
||||
|
||||
- A regex pattern like ``vllm.{IdentPara}+{IdentPara}``\: Decoupled generation (vLLM) and training allocations with corresponding identical parallelization strategies. Note that the pipeline parallel degree of vLLM can only be 1.
|
||||
- A regex pattern like ``{vllm|sglang}.{IdentPara}+{IdentPara}``\: Decoupled generation and training allocations with corresponding identical parallelization strategies.
|
||||
|
||||
- Key-value pairs with MFC names and their parallel strategies in the whole cluster, e.g., ``actor_gen:d4m2p1,*:d2p2m2`` specifies a ``d4m2p1`` strategy for actor generation and ``d2p2m2`` for other MFCs in a world of 8 GPUs.
|
||||
|
||||
|
@ -373,7 +374,7 @@ class CommonExperimentConfig(Experiment):
|
|||
f"and n_gpus_per_node {self.n_gpus_per_node}."
|
||||
)
|
||||
|
||||
self.__check_legal_allocation_options()
|
||||
self._check_legal_allocation_options()
|
||||
|
||||
self._allocation_mode = AllocationMode.from_str(self.allocation_mode)
|
||||
|
||||
|
@ -452,9 +453,9 @@ class CommonExperimentConfig(Experiment):
|
|||
raise ValueError(
|
||||
"The multiplication of 3D parallel degrees "
|
||||
"does not equal to the number of gpus. "
|
||||
"Note that the device mesh of vLLM should be disjoint from the device mesh of other MFCs, "
|
||||
"Note that the device mesh of vLLM/SGLang should be disjoint from the device mesh of other MFCs, "
|
||||
"so their summation should be equal to the total number of gpus. "
|
||||
f"dp={dp}, pp={pp}, mp={mp}, vllm.dp={gdp}, vllm.pp={gpp}, vllm.mp={gmp}, "
|
||||
f"dp={dp}, pp={pp}, mp={mp}, gen.dp={gdp}, gen.pp={gpp}, gen.mp={gmp}, "
|
||||
f"n_nodes={self.n_nodes}, n_gpus_per_node={self.n_gpus_per_node}"
|
||||
)
|
||||
alloc = RPCAllocation(
|
||||
|
@ -535,7 +536,7 @@ class CommonExperimentConfig(Experiment):
|
|||
def _get_model_worker_configs(
|
||||
self, rpc_allocs: List[RPCAllocation]
|
||||
) -> List[ModelWorker]:
|
||||
self.__run_model_sanity_check(rpc_allocs)
|
||||
self._run_model_sanity_check(rpc_allocs)
|
||||
|
||||
model_worker = []
|
||||
shard_counter = defaultdict(lambda: 0)
|
||||
|
@ -557,7 +558,7 @@ class CommonExperimentConfig(Experiment):
|
|||
tokenizer_name_or_path=self.tokenizer_name_or_path,
|
||||
)
|
||||
|
||||
# vLLM enabled model worker, shortcut case
|
||||
# decoupled allocation, shortcut case
|
||||
if (
|
||||
self._allocation_mode.is_decoupled()
|
||||
and self.gen_device_mesh.mapping[i, j]
|
||||
|
@ -574,22 +575,30 @@ class CommonExperimentConfig(Experiment):
|
|||
is_train=False,
|
||||
)
|
||||
model_cfg = self.models[model_name.role]
|
||||
global vLLM_HYBRID_TRAIN_DECOUPLE_ALLOC_WARN
|
||||
|
||||
gen_backend_name = ""
|
||||
if self._allocation_mode.is_decoupled_vllm():
|
||||
gen_backend_name = "vllm"
|
||||
elif self._allocation_mode.is_decoupled_sglang():
|
||||
gen_backend_name = "sglang"
|
||||
backend_cfg = getattr(model_cfg, gen_backend_name)
|
||||
|
||||
global GEN_HYBRID_TRAIN_DECOUPLE_ALLOC_WARN
|
||||
if (
|
||||
model_cfg.vllm.hybrid_train
|
||||
and not vLLM_HYBRID_TRAIN_DECOUPLE_ALLOC_WARN
|
||||
backend_cfg.hybrid_train
|
||||
and not GEN_HYBRID_TRAIN_DECOUPLE_ALLOC_WARN
|
||||
):
|
||||
logger.warning(
|
||||
"vLLM hybrid_train=True takes no effect for the decoupled allocation"
|
||||
"hybrid_train=True takes no effect for the decoupled allocation"
|
||||
)
|
||||
vLLM_HYBRID_TRAIN_DECOUPLE_ALLOC_WARN = True
|
||||
model_cfg.vllm.hybrid_train = False
|
||||
check_valid_vllm(model_name.role, model_cfg.vllm, rpc_allocs)
|
||||
GEN_HYBRID_TRAIN_DECOUPLE_ALLOC_WARN = True
|
||||
backend_cfg.hybrid_train = False
|
||||
|
||||
if gen_backend_name == "vllm":
|
||||
check_valid_vllm(model_name.role, model_cfg.vllm, rpc_allocs)
|
||||
|
||||
shard_idx = shard_counter[model_name]
|
||||
vllm_dict_args: Dict[str, Any] = OmegaConf.to_container(
|
||||
model_cfg.vllm, resolve=True
|
||||
)
|
||||
dict_args: Dict[str, Any] = asdict(backend_cfg)
|
||||
mw.shards.append(
|
||||
StandaloneModelShardAbstraction(
|
||||
id=ModelShardID(
|
||||
|
@ -603,11 +612,11 @@ class CommonExperimentConfig(Experiment):
|
|||
"tokenizer", args=dict(tokenizer_path=model_cfg.path)
|
||||
),
|
||||
backend=ModelBackendAbstraction(
|
||||
"vllm",
|
||||
gen_backend_name,
|
||||
args=dict(
|
||||
model_path=model_cfg.path,
|
||||
dtype="bfloat16" if model_cfg.bf16 else "float16",
|
||||
**vllm_dict_args,
|
||||
**dict_args,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
@ -684,12 +693,13 @@ class CommonExperimentConfig(Experiment):
|
|||
rpc.is_generate() for rpc in rpcs
|
||||
):
|
||||
assert len(rpcs) == 1 and rpcs[0].is_generate(), rpcs
|
||||
vllm_dict_args: Dict[str, Any] = asdict(model_cfg.vllm)
|
||||
dict_args: Dict[str, Any] = asdict(model_cfg.vllm)
|
||||
check_valid_vllm(model_name.role, model_cfg.vllm, rpc_allocs)
|
||||
backend = ModelBackendAbstraction(
|
||||
"vllm",
|
||||
args=dict(
|
||||
model_path=model_cfg.path,
|
||||
**vllm_dict_args,
|
||||
**dict_args,
|
||||
),
|
||||
)
|
||||
else:
|
||||
|
@ -698,7 +708,6 @@ class CommonExperimentConfig(Experiment):
|
|||
"vllm",
|
||||
"sglang",
|
||||
]:
|
||||
print(rpcs, model_name, backend.type_)
|
||||
raise ValueError(
|
||||
"vLLM or SGLang is not enabled for generation. "
|
||||
"This behavior has been deprecated. "
|
||||
|
@ -706,7 +715,6 @@ class CommonExperimentConfig(Experiment):
|
|||
"or model.sglang.hybrid_train=True."
|
||||
)
|
||||
|
||||
check_valid_vllm(model_name.role, model_cfg.vllm, rpc_allocs)
|
||||
if mapping[i, j]:
|
||||
shard_idx = shard_counter[model_name]
|
||||
mw.shards.append(
|
||||
|
@ -749,7 +757,7 @@ class CommonExperimentConfig(Experiment):
|
|||
evaluator=self.auto_eval_config,
|
||||
)
|
||||
|
||||
def __check_legal_allocation_options(self):
|
||||
def _check_legal_allocation_options(self):
|
||||
if self.n_nodes > 1 and self.mode == "local":
|
||||
raise ValueError(
|
||||
"Cannot run multi-node experiment in local mode, "
|
||||
|
@ -791,7 +799,7 @@ class CommonExperimentConfig(Experiment):
|
|||
f"RPC {rpc.name} model name {rpc.model_name.role} is not in models."
|
||||
)
|
||||
|
||||
def __run_model_sanity_check(self, rpc_allocs: List[RPCAllocation]):
|
||||
def _run_model_sanity_check(self, rpc_allocs: List[RPCAllocation]):
|
||||
for alloc in rpc_allocs:
|
||||
check_valid_parallel_batch_size(alloc)
|
||||
for role, model in self.models.items():
|
||||
|
|
|
@ -90,6 +90,7 @@ class PPOHyperparameters:
|
|||
eps_clip: float = 0.2
|
||||
value_eps_clip: float = 0.2
|
||||
disable_value: bool = False
|
||||
recompute_logprob: bool = False
|
||||
max_reward_clip: float = 20.0
|
||||
reward_output_scaling: float = 1.0
|
||||
reward_output_bias: float = 0.0
|
||||
|
@ -198,6 +199,7 @@ class PPOMATHConfig(CommonExperimentConfig):
|
|||
critic_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
||||
rew_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
||||
ref_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
||||
actor_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
||||
|
||||
dataset: PromptOnlyDatasetConfig = dataclasses.field(
|
||||
default_factory=PromptOnlyDatasetConfig
|
||||
|
@ -336,6 +338,14 @@ class PPOMATHConfig(CommonExperimentConfig):
|
|||
+ 128,
|
||||
),
|
||||
)
|
||||
rollout_output_keys = [
|
||||
"seq_no_eos_mask",
|
||||
"packed_input_ids",
|
||||
"packed_logprobs",
|
||||
"prompt_mask",
|
||||
]
|
||||
if self.ppo.recompute_logprob:
|
||||
rollout_output_keys.remove("packed_logprobs")
|
||||
rollout = MFCDef(
|
||||
name="actor_gen",
|
||||
model_name="actor",
|
||||
|
@ -345,12 +355,21 @@ class PPOMATHConfig(CommonExperimentConfig):
|
|||
model_path=self.actor.path,
|
||||
interface_impl=actor_interface,
|
||||
input_keys=["packed_prompts"],
|
||||
output_keys=[
|
||||
"seq_no_eos_mask",
|
||||
"packed_input_ids",
|
||||
"packed_logprobs",
|
||||
"prompt_mask",
|
||||
],
|
||||
output_keys=rollout_output_keys,
|
||||
n_seqs=self.dataset.train_bs_n_seqs,
|
||||
)
|
||||
|
||||
actor_inf = MFCDef(
|
||||
name="actor_inf",
|
||||
model_name="actor",
|
||||
mb_spec=self.actor_inf.mb_spec,
|
||||
interface_type=ModelInterfaceType.INFERENCE,
|
||||
model_type=self.actor.type,
|
||||
model_path=self.actor.path,
|
||||
interface_impl=actor_interface,
|
||||
input_keys=["packed_input_ids"],
|
||||
output_keys=["packed_logprobs"],
|
||||
output_key_remap=dict(logprobs="packed_logprobs"),
|
||||
n_seqs=self.dataset.train_bs_n_seqs,
|
||||
)
|
||||
|
||||
|
@ -380,6 +399,7 @@ class PPOMATHConfig(CommonExperimentConfig):
|
|||
min_n_seqs_per_pass=1 / self.group_size,
|
||||
input_keys=inf_ref_inputs,
|
||||
output_keys=["packed_ref_logprobs"],
|
||||
output_key_remap=dict(logprobs="packed_ref_logprobs"),
|
||||
n_seqs=self.dataset.train_bs_n_seqs,
|
||||
)
|
||||
|
||||
|
@ -445,45 +465,40 @@ class PPOMATHConfig(CommonExperimentConfig):
|
|||
min_n_seqs_per_pass=self.ppo.ppo_n_minibatches / self.group_size,
|
||||
n_seqs=self.dataset.train_bs_n_seqs,
|
||||
)
|
||||
|
||||
rpcs = {
|
||||
"actor_gen": rollout,
|
||||
"actor_train": train_actor,
|
||||
"critic_inf": inf_values,
|
||||
"critic_train": train_critic,
|
||||
"ref_inf": inf_ref_logits,
|
||||
"actor_inf": actor_inf,
|
||||
"rew_inf": inf_reward,
|
||||
}
|
||||
if self.ppo.disable_value:
|
||||
return {
|
||||
"actor_gen": rollout,
|
||||
"actor_train": train_actor,
|
||||
# "critic_inf": inf_values,
|
||||
# "critic_train": train_critic,
|
||||
"ref_inf": inf_ref_logits,
|
||||
"rew_inf": inf_reward,
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"actor_gen": rollout,
|
||||
"actor_train": train_actor,
|
||||
"critic_inf": inf_values,
|
||||
"critic_train": train_critic,
|
||||
"ref_inf": inf_ref_logits,
|
||||
"rew_inf": inf_reward,
|
||||
}
|
||||
rpcs.pop("critic_inf")
|
||||
rpcs.pop("critic_train")
|
||||
if not self.ppo.recompute_logprob:
|
||||
rpcs.pop("actor_inf")
|
||||
return rpcs
|
||||
|
||||
@property
|
||||
def allocations(self):
|
||||
allocs = {
|
||||
"actor_gen": self.actor_gen,
|
||||
"actor_train": self.actor_train,
|
||||
"critic_inf": self.critic_inf,
|
||||
"critic_train": self.critic_train,
|
||||
"ref_inf": self.ref_inf,
|
||||
"actor_inf": self.actor_inf,
|
||||
"rew_inf": self.rew_inf,
|
||||
}
|
||||
if self.ppo.disable_value:
|
||||
return {
|
||||
"actor_gen": self.actor_gen,
|
||||
"actor_train": self.actor_train,
|
||||
# "critic_inf": self.critic_inf,
|
||||
# "critic_train": self.critic_train,
|
||||
"ref_inf": self.ref_inf,
|
||||
"rew_inf": self.rew_inf,
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"actor_gen": self.actor_gen,
|
||||
"actor_train": self.actor_train,
|
||||
"critic_inf": self.critic_inf,
|
||||
"critic_train": self.critic_train,
|
||||
"ref_inf": self.ref_inf,
|
||||
"rew_inf": self.rew_inf,
|
||||
}
|
||||
allocs.pop("critic_inf")
|
||||
allocs.pop("critic_train")
|
||||
if not self.ppo.recompute_logprob:
|
||||
allocs.pop("actor_inf")
|
||||
return allocs
|
||||
|
||||
@property
|
||||
def datasets(self):
|
||||
|
|
|
@ -110,32 +110,40 @@ def make_inf_backend_config(
|
|||
def resolve_replica_ids(
|
||||
rpc_allocs: List[RPCAllocation], models: Dict[str, ModelTrainEvalConfig]
|
||||
):
|
||||
role_cnt = collections.defaultdict(int)
|
||||
first_device_mesh = dict()
|
||||
first_parallel = dict()
|
||||
first_rpc = dict()
|
||||
role_rpcs = collections.defaultdict(list)
|
||||
for alloc in rpc_allocs:
|
||||
rpc = alloc.rpc
|
||||
if rpc.role not in first_device_mesh:
|
||||
first_device_mesh[rpc.role] = alloc.device_mesh
|
||||
first_parallel[rpc.role] = alloc.parallel
|
||||
first_rpc[rpc.role] = rpc
|
||||
role_rpcs[rpc.role].append(alloc)
|
||||
|
||||
for role, allocs in role_rpcs.items():
|
||||
cnt = len(allocs)
|
||||
if cnt == 1:
|
||||
allocs[0].rpc.model_name = ModelName(role, 0)
|
||||
continue
|
||||
model_cfg = models[rpc.role]
|
||||
if (rpc.is_train() and first_rpc[rpc.role].is_generate()) or (
|
||||
rpc.is_generate() and first_rpc[rpc.role].is_train()
|
||||
):
|
||||
if model_cfg.vllm.hybrid_train:
|
||||
role_cnt[rpc.role] += 1
|
||||
rpc.model_name = ModelName(rpc.role, role_cnt[rpc.role])
|
||||
rpcs = [alloc.rpc for alloc in allocs]
|
||||
if any(rpc.is_train() for rpc in rpcs):
|
||||
main_alloc = next(alloc for alloc in allocs if alloc.rpc.is_train())
|
||||
elif any(rpc.is_inference() for rpc in rpcs):
|
||||
main_alloc = next(alloc for alloc in allocs if alloc.rpc.is_inference())
|
||||
else:
|
||||
main_alloc = allocs[0]
|
||||
main_alloc.rpc.model_name = ModelName(role, 0)
|
||||
i = 1
|
||||
for alloc in allocs:
|
||||
if alloc.rpc.name == main_alloc.rpc.name:
|
||||
continue
|
||||
if alloc.device_mesh != first_device_mesh[rpc.role] or not parallelism_eq(
|
||||
alloc.parallel, first_parallel[rpc.role]
|
||||
):
|
||||
role_cnt[rpc.role] += 1
|
||||
rpc.model_name = ModelName(rpc.role, role_cnt[rpc.role])
|
||||
continue
|
||||
assert rpc.model_name.replica_id == 0
|
||||
same_alloc = alloc.device_mesh == main_alloc.device_mesh and parallelism_eq(
|
||||
alloc.parallel, main_alloc.parallel
|
||||
)
|
||||
if not same_alloc or (
|
||||
alloc.rpc.is_generate()
|
||||
and main_alloc.rpc.is_train()
|
||||
and (models[role].vllm.hybrid_train)
|
||||
):
|
||||
alloc.rpc.model_name = ModelName(role, i)
|
||||
i += 1
|
||||
else:
|
||||
alloc.rpc.model_name = ModelName(role, 0)
|
||||
|
||||
|
||||
def resolve_rpc_hooks(
|
||||
|
@ -207,6 +215,7 @@ class AllocationType(enum.Enum):
|
|||
HEURISTIC = 4
|
||||
SEARCH = 5
|
||||
DECOUPLED_SGLANG = 6
|
||||
DECOUPLED_MOCK = 7
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
@ -218,6 +227,7 @@ class AllocationMode:
|
|||
return self.type_ in [
|
||||
AllocationType.DECOUPLED_vLLM,
|
||||
AllocationType.DECOUPLED_SGLANG,
|
||||
AllocationType.DECOUPLED_MOCK,
|
||||
]
|
||||
|
||||
def is_decoupled_vllm(self):
|
||||
|
@ -226,6 +236,9 @@ class AllocationMode:
|
|||
def is_decoupled_sglang(self):
|
||||
return self.type_ == AllocationType.DECOUPLED_SGLANG
|
||||
|
||||
def is_decoupled_mock(self):
|
||||
return self.type_ == AllocationType.DECOUPLED_MOCK
|
||||
|
||||
def is_global_hybrid(self):
|
||||
return self.type_ == AllocationType.GLOBAL_HYBRID
|
||||
|
||||
|
@ -246,6 +259,8 @@ class AllocationMode:
|
|||
return cls(AllocationType.DECOUPLED_vLLM, alloc_decoupled)
|
||||
elif "sglang" in allocation_mode:
|
||||
return cls(AllocationType.DECOUPLED_SGLANG, alloc_decoupled)
|
||||
elif "mock" in allocation_mode:
|
||||
return cls(AllocationType.DECOUPLED_MOCK, alloc_decoupled)
|
||||
if alloc_3d:
|
||||
return cls(AllocationType.GLOBAL_HYBRID, alloc_3d)
|
||||
if alloc_hybrid:
|
||||
|
@ -272,7 +287,7 @@ class AllocationMode:
|
|||
@staticmethod
|
||||
def extract_decoupled_alloc(allocation_mode: str) -> Dict | None:
|
||||
pattern = re.compile(
|
||||
r"(?:(?:vllm|sglang)\.(.+?)\+(.+))|(?:(.+?)\+(?:vllm|sglang)\.(.+))"
|
||||
r"(?:(?:vllm|sglang|mock)\.(.+?)\+(.+))|(?:(.+?)\+(?:vllm|sglang|mock)\.(.+))"
|
||||
)
|
||||
m = pattern.match(allocation_mode)
|
||||
if not m:
|
||||
|
|
|
@ -834,6 +834,7 @@ class ReaLMegatronEngine(model_api.PipelinableEngine):
|
|||
self,
|
||||
input_: SequenceSample,
|
||||
mb_spec: MicroBatchSpec,
|
||||
output_seqlens: List[List[int]] | None = None,
|
||||
post_hook: Callable[[torch.Tensor, SequenceSample], Any] | None = None,
|
||||
aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
|
||||
):
|
||||
|
@ -841,6 +842,7 @@ class ReaLMegatronEngine(model_api.PipelinableEngine):
|
|||
input_=input_,
|
||||
mb_spec=mb_spec,
|
||||
post_hook=post_hook,
|
||||
output_seqlens=output_seqlens,
|
||||
aggregate_fn=aggregate_fn,
|
||||
)
|
||||
|
||||
|
|
|
@ -179,12 +179,14 @@ class MockTrainEngine(model_api.PipelinableEngine):
|
|||
self,
|
||||
input_: SequenceSample,
|
||||
mb_spec: MicroBatchSpec,
|
||||
output_seqlens: List[List[int]] | None = None,
|
||||
post_hook: Callable[[torch.Tensor, SequenceSample], Any] | None = None,
|
||||
aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
|
||||
):
|
||||
return self.inf_engine.forward(
|
||||
input_=input_,
|
||||
mb_spec=mb_spec,
|
||||
output_seqlens=output_seqlens,
|
||||
post_hook=post_hook,
|
||||
aggregate_fn=aggregate_fn,
|
||||
)
|
||||
|
|
|
@ -444,13 +444,13 @@ class PPOActorInterface(model_api.ModelInterface):
|
|||
)
|
||||
|
||||
res = SequenceSample(
|
||||
keys=["packed_ref_logprobs"],
|
||||
keys=["logprobs"],
|
||||
ids=input_.ids,
|
||||
dtypes=dict(packed_ref_logprobs=model.module.dtype),
|
||||
trailing_shapes=dict(packed_ref_logprobs=()),
|
||||
data=dict(packed_ref_logprobs=logprobs),
|
||||
dtypes=dict(logprobs=model.module.dtype),
|
||||
trailing_shapes=dict(logprobs=()),
|
||||
data=dict(logprobs=logprobs),
|
||||
seqlens=dict(
|
||||
packed_ref_logprobs=[
|
||||
logprobs=[
|
||||
[x - 1 for x in slen] for slen in input_.seqlens["packed_input_ids"]
|
||||
]
|
||||
),
|
||||
|
|
Loading…
Reference in New Issue