AReaL/patch/sglang/v0.4.6.post4.patch

145 lines
5.7 KiB
Diff

diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 5390668c..db370d19 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -687,10 +687,21 @@ class FlushCacheReqOutput:
success: bool
+@dataclass
+class InterruptAllReqInput:  # no fields; handled by Scheduler.interrupt_all_requests
+    pass
+
+
+@dataclass
+class InterruptAllReqOutput:  # reply built by Scheduler.interrupt_all_requests
+    num_interrupted_requests: int  # len(waiting_queue) + len(running_batch.reqs)
+
+
@dataclass
class UpdateWeightFromDiskReqInput:
# The model path with the new weights
model_path: str
+ allow_interrupt: bool = False
# The format to load the weights
load_format: Optional[str] = None
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 1178eec5..318dee33 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -73,6 +73,8 @@ from sglang.srt.managers.io_struct import (
HealthCheckOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
+ InterruptAllReqInput,
+ InterruptAllReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
@@ -427,6 +429,7 @@ class Scheduler(
# Init request dispatcher
self._request_dispatcher = TypeBasedDispatcher(
[
+ (InterruptAllReqInput, self.interrupt_all_requests),
(TokenizedGenerateReqInput, self.handle_generate_request),
(TokenizedEmbeddingReqInput, self.handle_embedding_request),
(FlushCacheReqInput, self.flush_cache_wrapped),
@@ -1971,6 +1974,15 @@ class Scheduler(
def _pause_engine(self) -> Tuple[List[Req], int]:
raise NotImplementedError()
+    def interrupt_all_requests(self, recv_req: InterruptAllReqInput):
+        total = len(self.waiting_queue) + len(self.running_batch.reqs)
+        for queued in self.waiting_queue:  # waiting: zero-token budget
+            queued.sampling_params.max_new_tokens = 0
+        for active in self.running_batch.reqs:  # running: cap at current output length
+            active.sampling_params.max_new_tokens = len(active.output_ids)
+        logger.info(f"Interrupt {total} requests.")
+        return InterruptAllReqOutput(total)
+
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
"""In-place update of the weights from disk."""
success, message = self.tp_worker.update_weights_from_disk(recv_req)
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index b646fae1..c668728b 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -80,6 +80,8 @@ from sglang.srt.managers.io_struct import (
HealthCheckOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
+ InterruptAllReqInput,
+ InterruptAllReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
@@ -279,6 +281,9 @@ class TokenizerManager:
self.slow_down_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
+ self.interrupt_requests_communicator = _Communicator(
+ self.send_to_scheduler, server_args.dp_size
+ )
self.flush_cache_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
@@ -309,6 +314,10 @@ class TokenizerManager:
UpdateWeightFromDiskReqOutput,
self._handle_update_weights_from_disk_req_output,
),
+ (
+ InterruptAllReqOutput,
+ self.interrupt_requests_communicator.handle_recv,
+ ),
(
InitWeightsUpdateGroupReqOutput,
self.init_weights_update_group_communicator.handle_recv,
@@ -799,6 +808,13 @@ class TokenizerManager:
) -> Tuple[bool, str]:
self.auto_create_handle_loop()
+ if obj.allow_interrupt:
+ num_interrupted_requests = await self.interrupt_all_requests(
+ InterruptAllReqInput()
+ )
+ # Sleep 0.1 s so the interrupt can take effect before the weights are updated
+ await asyncio.sleep(0.1)
+
# default the load format to the server_args
if obj.load_format is None:
obj.load_format = self.server_args.load_format
@@ -808,7 +824,12 @@ class TokenizerManager:
# Hold the lock if it is not async. This means that weight sync
# cannot run while requests are in progress.
async with self.model_update_lock.writer_lock:
- return await self._wait_for_model_update_from_disk(obj)
+ success, message, n_paused = (
+ await self._wait_for_model_update_from_disk(obj)
+ )
+ if obj.allow_interrupt:
+ return success, message, num_interrupted_requests
+ return success, message, n_paused
async def _wait_for_model_update_from_disk(
self, obj: UpdateWeightFromDiskReqInput
@@ -881,6 +902,18 @@ class TokenizerManager:
result = (await self.update_weights_from_tensor_communicator(obj))[0]
return result.success, result.message
+    async def interrupt_all_requests(
+        self, obj: InterruptAllReqInput, request: Optional[fastapi.Request] = None
+    ) -> "Union[int, List[int]]":  # quoted: never evaluated at import time
+        """Forward `obj` to the scheduler(s) and return the interrupted-request count.
+
+        Returns a single int when `dp_size == 1`, else a list with one int per reply.
+        """
+        self.auto_create_handle_loop()
+        replies = await self.interrupt_requests_communicator(obj)
+        counts = [reply.num_interrupted_requests for reply in replies]
+        return counts[0] if self.server_args.dp_size == 1 else counts
+
async def get_weights_by_name(
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
):