mirror of https://github.com/inclusionAI/AReaL
145 lines
5.7 KiB
Diff
145 lines
5.7 KiB
Diff
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
|
|
index 174656b2..33fe0a5f 100644
|
|
--- a/python/sglang/srt/managers/io_struct.py
|
|
+++ b/python/sglang/srt/managers/io_struct.py
|
|
@@ -687,10 +687,21 @@ class FlushCacheReqOutput:
|
|
success: bool
|
|
|
|
|
|
+@dataclass
|
|
+class InterruptAllReqInput:
|
|
+ pass
|
|
+
|
|
+
|
|
+@dataclass
|
|
+class InterruptAllReqOutput:
|
|
+ num_interrupted_requests: int
|
|
+
|
|
+
|
|
@dataclass
|
|
class UpdateWeightFromDiskReqInput:
|
|
# The model path with the new weights
|
|
model_path: str
|
|
+ allow_interrupt: bool = False
|
|
# The format to load the weights
|
|
load_format: Optional[str] = None
|
|
|
|
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
|
|
index 8891115c..843a8a82 100644
|
|
--- a/python/sglang/srt/managers/scheduler.py
|
|
+++ b/python/sglang/srt/managers/scheduler.py
|
|
@@ -70,6 +70,8 @@ from sglang.srt.managers.io_struct import (
|
|
HealthCheckOutput,
|
|
InitWeightsUpdateGroupReqInput,
|
|
InitWeightsUpdateGroupReqOutput,
|
|
+ InterruptAllReqInput,
|
|
+ InterruptAllReqOutput,
|
|
OpenSessionReqInput,
|
|
OpenSessionReqOutput,
|
|
ProfileReq,
|
|
@@ -419,6 +421,7 @@ class Scheduler(
|
|
# Init request dispatcher
|
|
self._request_dispatcher = TypeBasedDispatcher(
|
|
[
|
|
+ (InterruptAllReqInput, self.interrupt_all_requests),
|
|
(TokenizedGenerateReqInput, self.handle_generate_request),
|
|
(TokenizedEmbeddingReqInput, self.handle_embedding_request),
|
|
(FlushCacheReqInput, self.flush_cache_wrapped),
|
|
@@ -1938,6 +1941,15 @@ class Scheduler(
|
|
def _pause_engine(self) -> Tuple[List[Req], int]:
|
|
raise NotImplementedError()
|
|
|
|
+ def interrupt_all_requests(self, recv_req: InterruptAllReqInput):
|
|
+ num = len(self.waiting_queue) + len(self.running_batch.reqs)
|
|
+ for req in self.waiting_queue:
|
|
+ req.sampling_params.max_new_tokens = 0
|
|
+ for req in self.running_batch.reqs:
|
|
+ req.sampling_params.max_new_tokens = len(req.output_ids)
|
|
+ logger.info(f"Interrupt {num} requests.")
|
|
+ return InterruptAllReqOutput(num)
|
|
+
|
|
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
|
|
"""In-place update of the weights from disk."""
|
|
success, message = self.tp_worker.update_weights_from_disk(recv_req)
|
|
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
|
|
index 82709b09..bfab3ce7 100644
|
|
--- a/python/sglang/srt/managers/tokenizer_manager.py
|
|
+++ b/python/sglang/srt/managers/tokenizer_manager.py
|
|
@@ -76,6 +76,8 @@ from sglang.srt.managers.io_struct import (
|
|
HealthCheckOutput,
|
|
InitWeightsUpdateGroupReqInput,
|
|
InitWeightsUpdateGroupReqOutput,
|
|
+ InterruptAllReqInput,
|
|
+ InterruptAllReqOutput,
|
|
OpenSessionReqInput,
|
|
OpenSessionReqOutput,
|
|
ProfileReq,
|
|
@@ -265,6 +267,9 @@ class TokenizerManager:
|
|
self.resume_memory_occupation_communicator = _Communicator(
|
|
self.send_to_scheduler, server_args.dp_size
|
|
)
|
|
+ self.interrupt_requests_communicator = _Communicator(
|
|
+ self.send_to_scheduler, server_args.dp_size
|
|
+ )
|
|
self.flush_cache_communicator = _Communicator(
|
|
self.send_to_scheduler, server_args.dp_size
|
|
)
|
|
@@ -294,6 +299,10 @@ class TokenizerManager:
|
|
UpdateWeightFromDiskReqOutput,
|
|
self._handle_update_weights_from_disk_req_output,
|
|
),
|
|
+ (
|
|
+ InterruptAllReqOutput,
|
|
+ self.interrupt_requests_communicator.handle_recv,
|
|
+ ),
|
|
(
|
|
InitWeightsUpdateGroupReqOutput,
|
|
self.init_weights_update_group_communicator.handle_recv,
|
|
@@ -767,6 +776,13 @@ class TokenizerManager:
|
|
) -> Tuple[bool, str]:
|
|
self.auto_create_handle_loop()
|
|
|
|
+ if obj.allow_interrupt:
|
|
+ num_interrupted_requests = await self.interrupt_all_requests(
|
|
+ InterruptAllReqInput()
|
|
+ )
|
|
+ # Sleep briefly so the interrupt has time to take effect before continuing
|
|
+ await asyncio.sleep(0.1)
|
|
+
|
|
# default the load format to the server_args
|
|
if obj.load_format is None:
|
|
obj.load_format = self.server_args.load_format
|
|
@@ -776,7 +792,12 @@ class TokenizerManager:
|
|
# Hold the lock if it is not async. This means that weight sync
|
|
# cannot run while requests are in progress.
|
|
async with self.model_update_lock.writer_lock:
|
|
- return await self._wait_for_model_update_from_disk(obj)
|
|
+ success, message, n_paused = (
|
|
+ await self._wait_for_model_update_from_disk(obj)
|
|
+ )
|
|
+ if obj.allow_interrupt:
|
|
+ return success, message, num_interrupted_requests
|
|
+ return success, message, n_paused
|
|
|
|
async def _wait_for_model_update_from_disk(
|
|
self, obj: UpdateWeightFromDiskReqInput
|
|
@@ -849,6 +870,18 @@ class TokenizerManager:
|
|
result = (await self.update_weights_from_tensor_communicator(obj))[0]
|
|
return result.success, result.message
|
|
|
|
+ async def interrupt_all_requests(
|
|
+ self,
|
|
+ obj: InterruptAllReqInput,
|
|
+ request: Optional[fastapi.Request] = None,
|
|
+ ) -> Tuple[bool, str]:
|
|
+ self.auto_create_handle_loop()
|
|
+ result = await self.interrupt_requests_communicator(obj)
|
|
+ if self.server_args.dp_size == 1:
|
|
+ return result[0].num_interrupted_requests
|
|
+ else:
|
|
+ return [r.num_interrupted_requests for r in result]
|
|
+
|
|
async def get_weights_by_name(
|
|
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
|
|
):
|