PullRequest: 42 Fix the error-raising logic and the bug when using SGLang with tensor parallelism

Merge branch fw/patch20250317-3 of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/42

Signed-off-by: 郭唯 <kira.gw@antgroup.com>


* .
This commit is contained in:
博惟 2025-03-17 20:37:43 +08:00
parent 312e84b62c
commit b619f64cda
3 changed files with 10 additions and 4 deletions

View File

@@ -365,6 +365,10 @@ class CommonExperimentConfig(Experiment):
),
)
@property
def _allocation_mode(self):
return AllocationMode.from_str(self.allocation_mode)
def _get_rpc_allocations(self) -> List[RPCAllocation]:
if self.allocation_mode == "manual" and self.nodelist is None:
logger.warning(
@@ -377,8 +381,6 @@ class CommonExperimentConfig(Experiment):
self._check_legal_allocation_options()
self._allocation_mode = AllocationMode.from_str(self.allocation_mode)
rpcs = self.rpcs
if self.allocation_mode == "search":
# assert self.mode == "slurm"

View File

@@ -268,12 +268,15 @@ class PPOMATHConfig(CommonExperimentConfig):
@property
def rpcs(self):
if (
self.dataset.max_prompt_len + self.ppo.gen.max_new_tokens
(self._allocation_mode.is_decoupled_vllm() or self.actor.vllm.hybrid_train)
and self.dataset.max_prompt_len + self.ppo.gen.max_new_tokens
> self.actor.vllm.max_seq_len_to_capture
and not self.actor.vllm.enforce_eager
):
raise RuntimeError(
f"vllm max seq len to capture {self.actor.vllm.max_seq_len_to_capture} is "
f"smaller than the prompt length + generation length {self.dataset.max_prompt_len + self.ppo.gen.max_new_tokens}"
f"smaller than the prompt length + generation length "
f"{self.dataset.max_prompt_len + self.ppo.gen.max_new_tokens}"
)
if not os.path.exists(os.getenv("REAL_MATH_METADATA_PATH")):
raise RuntimeError(

View File

@@ -369,6 +369,7 @@ class SGLangGenerationEngine(PipelinableEngine):
def update_weights_from_disk(self, path):
if constants.model_parallel_rank() != 0:
dist.barrier(group=constants.model_parallel_group())
return
async def _fn():
async with SGLangAPIClient(