mirror of https://github.com/inclusionAI/AReaL
PullRequest: 42 Fix the error-raising logic and the bug that occurs when using SGLang with tensor parallelism
Merge branch fw/patch20250317-3 of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/42 Signed-off-by: 郭唯 <kira.gw@antgroup.com> * .
This commit is contained in:
parent
312e84b62c
commit
b619f64cda
|
@ -365,6 +365,10 @@ class CommonExperimentConfig(Experiment):
|
|||
),
|
||||
)
|
||||
|
||||
@property
|
||||
def _allocation_mode(self):
|
||||
return AllocationMode.from_str(self.allocation_mode)
|
||||
|
||||
def _get_rpc_allocations(self) -> List[RPCAllocation]:
|
||||
if self.allocation_mode == "manual" and self.nodelist is None:
|
||||
logger.warning(
|
||||
|
@ -377,8 +381,6 @@ class CommonExperimentConfig(Experiment):
|
|||
|
||||
self._check_legal_allocation_options()
|
||||
|
||||
self._allocation_mode = AllocationMode.from_str(self.allocation_mode)
|
||||
|
||||
rpcs = self.rpcs
|
||||
if self.allocation_mode == "search":
|
||||
# assert self.mode == "slurm"
|
||||
|
|
|
@ -268,12 +268,15 @@ class PPOMATHConfig(CommonExperimentConfig):
|
|||
@property
|
||||
def rpcs(self):
|
||||
if (
|
||||
self.dataset.max_prompt_len + self.ppo.gen.max_new_tokens
|
||||
(self._allocation_mode.is_decoupled_vllm() or self.actor.vllm.hybrid_train)
|
||||
and self.dataset.max_prompt_len + self.ppo.gen.max_new_tokens
|
||||
> self.actor.vllm.max_seq_len_to_capture
|
||||
and not self.actor.vllm.enforce_eager
|
||||
):
|
||||
raise RuntimeError(
|
||||
f"vllm max seq len to capture {self.actor.vllm.max_seq_len_to_capture} is "
|
||||
f"smaller than the prompt length + generation length {self.dataset.max_prompt_len + self.ppo.gen.max_new_tokens}"
|
||||
f"smaller than the prompt length + generation length "
|
||||
f"{self.dataset.max_prompt_len + self.ppo.gen.max_new_tokens}"
|
||||
)
|
||||
if not os.path.exists(os.getenv("REAL_MATH_METADATA_PATH")):
|
||||
raise RuntimeError(
|
||||
|
|
|
@ -369,6 +369,7 @@ class SGLangGenerationEngine(PipelinableEngine):
|
|||
def update_weights_from_disk(self, path):
|
||||
if constants.model_parallel_rank() != 0:
|
||||
dist.barrier(group=constants.model_parallel_group())
|
||||
return
|
||||
|
||||
async def _fn():
|
||||
async with SGLangAPIClient(
|
||||
|
|
Loading…
Reference in New Issue