mirror of https://github.com/inclusionAI/AReaL
diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py
index 60091b9a..7a1c856b 100644
--- a/python/sglang/srt/layers/logits_processor.py
+++ b/python/sglang/srt/layers/logits_processor.py
@@ -15,6 +15,7 @@
import dataclasses
import logging
+import os
from typing import List, Optional, Union
import torch
@@ -44,6 +45,15 @@ from sglang.srt.model_executor.forward_batch_info import (
)
from sglang.srt.utils import dump_to_file
+# When computing the input and output token logprobs, if the logprobs tensor
+# has too many rows, the peak memory usage becomes too high. For example,
+# if the logprobs have shape [10000, 150000], the peak memory usage is greater
+# than 10000 * 150000 * 4 / 1024 / 1024 = 5722.05 MB (4 is the size of a float).
+# So we split the logprobs into multiple chunks.
+LOGITS_PROCESSER_CHUNK_SIZE = int(
+    os.environ.get("SGLANG_LOGITS_PROCESSER_CHUNK_SIZE", "2048")
+)
+
logger = logging.getLogger(__name__)
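For reference, a rough sketch (not part of the patch) of the memory arithmetic behind the chunk size above; estimate_logprob_mib is a hypothetical helper, and the shapes and 4-byte float size mirror the comment in the diff.

import os

def estimate_logprob_mib(num_rows: int, vocab_size: int, bytes_per_elem: int = 4) -> float:
    # A [num_rows, vocab_size] float32 logprob matrix held in memory at once.
    return num_rows * vocab_size * bytes_per_elem / 1024 / 1024

chunk_size = int(os.environ.get("SGLANG_LOGITS_PROCESSER_CHUNK_SIZE", "2048"))
print(estimate_logprob_mib(10000, 150000))       # ~5722.05 MiB for the whole prompt at once
print(estimate_logprob_mib(chunk_size, 150000))  # ~1171.88 MiB for one 2048-row chunk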
@@ -286,15 +296,17 @@ class LogitsProcessor(nn.Module):
input_logprob_indices = None
else:
# Input logprobs are required.
-            # Find 3 different indices.
+            # Find 4 different indices.
# 1. pruned_states: hidden states that we want logprobs from.
# 2. sample_indices: Indices that have sampled tokens.
# 3. input_logprob_indices: Indices that have input logprob tokens.
+            # 4. sequence_index_mapping: map pruned_states indices to top_logprobs_nums and token_ids_logprobs indices.
sample_index_pt = -1
sample_indices = []
input_logprob_indices_pt = 0
input_logprob_indices = []
pt, pruned_states = 0, []
+            idx, sequence_index_mapping = 0, []
for extend_logprob_start_len, extend_len in zip(
logits_metadata.extend_logprob_start_lens_cpu,
logits_metadata.extend_seq_lens_cpu,
@@ -310,7 +322,11 @@ class LogitsProcessor(nn.Module):
# by a caller.
assert extend_len > start_len
pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
+                # sequence_index_mapping: repeat the current loop index once per pruned token
+                sequence_index_mapping.extend([idx] * (extend_len - start_len))
+                idx += 1
pt += extend_len
+
sample_index_pt += extend_len - start_len
sample_indices.append(sample_index_pt)
input_logprob_indices.extend(
@@ -321,6 +337,7 @@ class LogitsProcessor(nn.Module):
)
input_logprob_indices_pt += extend_len - start_len
+            sequence_index_mapping.append(idx - 1)
pruned_states = torch.cat(pruned_states)
sample_indices = torch.tensor(
sample_indices, device=pruned_states.device, dtype=torch.int64
@@ -329,12 +346,6 @@ class LogitsProcessor(nn.Module):
input_logprob_indices, device=pruned_states.device, dtype=torch.int64
)
-        # Compute logits for both input and sampled tokens.
-        logits = self._get_logits(pruned_states, lm_head, logits_metadata)
-        sampled_logits = (
-            logits[sample_indices] if sample_indices is not None else logits
-        )
-
if self.debug_tensor_dump_output_folder:
assert (
not self.do_tensor_parallel_all_gather
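To make the role of sequence_index_mapping concrete, here is a small standalone sketch (made-up row counts, not part of the patch) of how each pruned token row is mapped back to its owning sequence so that per-sequence metadata can later be sliced per chunk.

# Two sequences contribute 2 and 3 pruned token rows respectively.
pruned_row_counts = [2, 3]

sequence_index_mapping = []
for idx, n_rows in enumerate(pruned_row_counts):
    # Repeat the sequence index once per pruned token row.
    sequence_index_mapping.extend([idx] * n_rows)
# Trailing entry so a chunk ending exactly at the last row can still be sliced.
sequence_index_mapping.append(len(pruned_row_counts) - 1)

assert sequence_index_mapping == [0, 0, 1, 1, 1, 1]

# A chunk covering rows 1..4 touches both sequences of the batch:
start_idx, end_idx = 1, 5
chunk_slice = slice(sequence_index_mapping[start_idx], sequence_index_mapping[end_idx] + 1)
assert chunk_slice == slice(0, 2)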
@@ -370,67 +381,176 @@ class LogitsProcessor(nn.Module):
else:
assert False, "Should never reach"
+        del hidden_states
+
if not logits_metadata.extend_return_logprob:
+            # Compute logits for both input and sampled tokens.
+            logits = self._get_logits(pruned_states, lm_head, logits_metadata)
+            sampled_logits = (
+                logits[sample_indices] if sample_indices is not None else logits
+            )
+
# Decode mode or extend mode without return_logprob.
return LogitsProcessorOutput(
next_token_logits=sampled_logits,
hidden_states=hidden_states_to_store,
)
else:
-            input_logprobs = logits[input_logprob_indices]
-            del hidden_states, logits
+            # Computing logprobs requires a lot of memory, so we split pruned_states
+            # into chunks of rows, compute input_logprobs for each chunk separately,
+            # and then concatenate the results.
+            return self._compute_output_by_chunk(
+                pruned_states,
+                sample_indices,
+                hidden_states_to_store,
+                input_logprob_indices,
+                sequence_index_mapping,
+                lm_head,
+                logits_metadata,
+            )
-            # Normalize the logprob w/o temperature, top-p
-            pruned_lens = torch.tensor(
-                logits_metadata.extend_logprob_pruned_lens_cpu,
-                device=input_logprobs.device,
+    def _compute_output_by_chunk(
+        self,
+        pruned_states: torch.Tensor,
+        sample_indices: torch.Tensor,
+        hidden_states_to_store: Optional[torch.Tensor],
+        input_logprob_indices: torch.Tensor,
+        index_mapping: list[int],
+        lm_head: VocabParallelEmbedding,
+        logits_metadata: LogitsMetadata,
+    ) -> LogitsProcessorOutput:
+        """
+        compute logprobs for the output token from the hidden states.
+        To avoid using too much memory, we split pruned_states into chunks of
+        rows to compute input_logprobs separately, then concatenate the results.
+
+        Returns:
+            LogitsProcessorOutput: logits processor output class
+        """
+
+        # Normalize the logprob w/o temperature, top-p
+        pruned_lens = torch.tensor(
+            logits_metadata.extend_logprob_pruned_lens_cpu,
+            device=pruned_states.device,
+        )
+        if logits_metadata.temp_scaled_logprobs:
+            logits_metadata.temperature = torch.repeat_interleave(
+                logits_metadata.temperature.view(-1),
+                pruned_lens,
+            ).view(-1, 1)
+        if logits_metadata.top_p_normalized_logprobs:
+            logits_metadata.top_p = torch.repeat_interleave(
+                logits_metadata.top_p,
+                pruned_lens,
)
-            if logits_metadata.temp_scaled_logprobs:
-                logits_metadata.temperature = torch.repeat_interleave(
-                    logits_metadata.temperature.view(-1),
-                    pruned_lens,
-                ).view(-1, 1)
-            if logits_metadata.top_p_normalized_logprobs:
-                logits_metadata.top_p = torch.repeat_interleave(
-                    logits_metadata.top_p,
-                    pruned_lens,
-                )
-            input_logprobs = self.compute_temp_top_p_normalized_logprobs(
-                input_logprobs, logits_metadata
+
+        # The peak memory usage is proportional to the chunk size.
+        chunk_size = LOGITS_PROCESSER_CHUNK_SIZE
+        num_chunks = (pruned_states.shape[0] + chunk_size - 1) // chunk_size
+
+        input_token_logprobs = []
+        if logits_metadata.extend_return_top_logprob:
+            input_top_logprobs_val = []
+            input_top_logprobs_idx = []
+        else:
+            input_top_logprobs_val = None
+            input_top_logprobs_idx = None
+
+        if logits_metadata.extend_token_ids_logprob:
+            input_token_ids_logprobs_val = []
+            input_token_ids_logprobs_idx = []
+        else:
+            input_token_ids_logprobs_val = None
+            input_token_ids_logprobs_idx = None
+
+        # If a single sequence is split into multiple chunks, we need to keep track
+        # of the pruned length of the sequence in the previous chunks.
+        split_len_topk = 0
+        split_len_token_ids = 0
+
+        for i in range(num_chunks):
+            start_idx = i * chunk_size
+            end_idx = min((i + 1) * chunk_size, pruned_states.shape[0])
+
+            # Get indices for this chunk
+            chunk_mask = (input_logprob_indices >= start_idx) & (
+                input_logprob_indices < end_idx
+            )
+
+            global_indices = input_logprob_indices[chunk_mask]
+            chunk_indices = global_indices - start_idx
+            chunk_states = pruned_states[start_idx:end_idx]
+            chunk_logits = self._get_logits(chunk_states, lm_head, logits_metadata)
+
+            if chunk_indices.numel() == 0:
+                continue
+
+            # Compute the logprobs of the chunk
+            chunk_input_logprobs = chunk_logits[chunk_indices]
+            chunk_input_logprobs = self.compute_temp_top_p_normalized_logprobs(
+                chunk_input_logprobs, global_indices, logits_metadata
)
+            # For each chunk, we need to get the slice of the sequence_index_mapping
+            chunk_slice = slice(index_mapping[start_idx], index_mapping[end_idx] + 1)
+
# Get the logprob of top-k tokens
if logits_metadata.extend_return_top_logprob:
-                (
+                split_len_topk = self.get_top_logprobs(
+                    chunk_input_logprobs,
+                    logits_metadata,
+                    chunk_slice,
input_top_logprobs_val,
input_top_logprobs_idx,
-                ) = self.get_top_logprobs(input_logprobs, logits_metadata)
-            else:
-                input_top_logprobs_val = input_top_logprobs_idx = None
+                    split_len_topk,
+                )
# Get the logprob of given token id
if logits_metadata.extend_token_ids_logprob:
-                (
+                split_len_token_ids = self.get_token_ids_logprobs(
+                    chunk_input_logprobs,
+                    logits_metadata,
+                    chunk_slice,
input_token_ids_logprobs_val,
input_token_ids_logprobs_idx,
-                ) = self.get_token_ids_logprobs(input_logprobs, logits_metadata)
-            else:
-                input_token_ids_logprobs_val = input_token_ids_logprobs_idx = None
-
-            input_token_logprobs = input_logprobs[
-                torch.arange(input_logprobs.shape[0], device=input_logprobs.device),
-                logits_metadata.extend_input_logprob_token_ids_gpu,
-            ]
+                    split_len_token_ids,
+                )
-            return LogitsProcessorOutput(
-                next_token_logits=sampled_logits,
-                input_token_logprobs=input_token_logprobs,
-                input_top_logprobs_val=input_top_logprobs_val,
-                input_top_logprobs_idx=input_top_logprobs_idx,
-                hidden_states=hidden_states_to_store,
-                input_token_ids_logprobs_val=input_token_ids_logprobs_val,
-                input_token_ids_logprobs_idx=input_token_ids_logprobs_idx,
+            # Handle sampled logits for the chunk if needed
+            chunk_sample_mask = (sample_indices >= start_idx) & (
+                sample_indices < end_idx
)
+            if i == 0:  # Initialize sampled_logits on first chunk
+                sampled_logits = torch.empty(
+                    (sample_indices.shape[0], chunk_logits.shape[1]),
+                    dtype=chunk_logits.dtype,
+                    device=chunk_logits.device,
+                )
+            if chunk_sample_mask.any():
+                chunk_sample_indices = sample_indices[chunk_sample_mask] - start_idx
+                sampled_logits[chunk_sample_mask] = chunk_logits[chunk_sample_indices]
+
+            # Get the logprob of the requested token ids
+            chunk_input_token_logprobs = chunk_input_logprobs[
+                torch.arange(
+                    chunk_input_logprobs.shape[0], device=chunk_input_logprobs.device
+                ),
+                logits_metadata.extend_input_logprob_token_ids_gpu[start_idx:end_idx],
+            ]
+            input_token_logprobs.append(chunk_input_token_logprobs)
+
+        # Concatenate the results
+        input_token_logprobs = torch.cat(input_token_logprobs, dim=0)
+
+        return LogitsProcessorOutput(
+            hidden_states=hidden_states_to_store,
+            next_token_logits=sampled_logits,
+            input_token_logprobs=input_token_logprobs,
+            input_top_logprobs_val=input_top_logprobs_val,
+            input_top_logprobs_idx=input_top_logprobs_idx,
+            input_token_ids_logprobs_val=input_token_ids_logprobs_val,
+            input_token_ids_logprobs_idx=input_token_ids_logprobs_idx,
+        )
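The essence of _compute_output_by_chunk is the standard pattern of projecting hidden states to the vocabulary in row chunks so that only one chunk of logits is alive at a time. A simplified, self-contained sketch of that pattern is below; it omits the sampled-logits handling and per-sequence bookkeeping of the real method, and names and shapes are illustrative only.

import torch

def chunked_log_softmax(hidden: torch.Tensor, lm_head_weight: torch.Tensor, chunk_size: int = 2048) -> torch.Tensor:
    """Project hidden states to the vocab and take log_softmax chunk by chunk,
    so at most chunk_size x vocab_size logits exist at once."""
    outputs = []
    for start in range(0, hidden.shape[0], chunk_size):
        chunk = hidden[start : start + chunk_size]
        logits = chunk @ lm_head_weight.t()              # [chunk, vocab]
        outputs.append(torch.log_softmax(logits, dim=-1))
    return torch.cat(outputs, dim=0)

hidden = torch.randn(5000, 32)
lm_head_weight = torch.randn(1000, 32)                   # toy vocab of 1000
logprobs = chunked_log_softmax(hidden, lm_head_weight)
assert logprobs.shape == (5000, 1000)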
def _get_logits(
self,
@@ -498,60 +618,142 @@ class LogitsProcessor(nn.Module):
return logits
@staticmethod
-    def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
+    def get_top_logprobs(
+        logprobs: torch.Tensor,
+        logits_metadata: LogitsMetadata,
+        chunk_slice: slice,
+        input_top_logprobs_val: List,
+        input_top_logprobs_idx: List,
+        split_pruned_len: int,
+    ):
+        """Get top-k logprobs for each sequence in the chunk.
+
+        Args:
+            logprobs: Log probabilities tensor of shape [seq_len, vocab_size]
+            logits_metadata: Metadata containing top-k and pruned length info
+            chunk_slice: Slice of sequences to process
+            input_top_logprobs_val: List to store top-k logprob values
+            input_top_logprobs_idx: List to store top-k token indices
+            split_pruned_len: Length of pruned tokens from previous chunk
+
+        Returns:
+            int: Number of remaining tokens to process in next chunk
+        """
+
max_k = max(logits_metadata.top_logprobs_nums)
-        ret = all_logprobs.topk(max_k, dim=1)
+        ret = logprobs.topk(max_k, dim=1)
values = ret.values.tolist()
indices = ret.indices.tolist()
-        input_top_logprobs_val, input_top_logprobs_idx = [], []
-
pt = 0
-        for k, pruned_len in zip(
-            logits_metadata.top_logprobs_nums,
-            logits_metadata.extend_logprob_pruned_lens_cpu,
-        ):
+        next_split_pruned_len = 0
+        top_k_nums = logits_metadata.top_logprobs_nums[chunk_slice]
+        pruned_lens = logits_metadata.extend_logprob_pruned_lens_cpu[chunk_slice]
+
+        for n, (k, pruned_len) in enumerate(zip(top_k_nums, pruned_lens)):
+            # Adjust pruned length for first sequence
+            if n == 0:
+                pruned_len -= split_pruned_len
+            else:
+                split_pruned_len = 0
+
if pruned_len <= 0:
-                input_top_logprobs_val.append([])
-                input_top_logprobs_idx.append([])
+                if n == 0:
+                    input_top_logprobs_val.append([])
+                    input_top_logprobs_idx.append([])
continue
-            input_top_logprobs_val.append(
-                [values[pt + j][:k] for j in range(pruned_len)]
-            )
-            input_top_logprobs_idx.append(
-                [indices[pt + j][:k] for j in range(pruned_len)]
-            )
-            pt += pruned_len
+            val = []
+            idx = []
+            for j in range(pruned_len):
+                # Handle remaining tokens in next chunk if any
+                if pt + j >= len(values):
+                    next_split_pruned_len = split_pruned_len + j
+                    break
+                val.append(values[pt + j][:k])
+                idx.append(indices[pt + j][:k])
+
+            if split_pruned_len <= 0 and len(val) > 0:
+                input_top_logprobs_val.append(val)
+                input_top_logprobs_idx.append(idx)
+            else:
+                input_top_logprobs_val[-1].extend(val)
+                input_top_logprobs_idx[-1].extend(idx)
-        return input_top_logprobs_val, input_top_logprobs_idx
+            pt += pruned_len
+        return next_split_pruned_len
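The split_pruned_len return value is what lets a sequence whose rows straddle a chunk boundary be stitched back together: the next call extends the last per-sequence list instead of opening a new one. A toy sketch of that grouping logic with plain lists (made-up data, mirroring but not reusing the real helpers):

def group_rows(rows, pruned_lens, grouped, split_pruned_len=0):
    """Append rows to grouped, one sublist per sequence; return how many rows
    of the current sequence spill over into the next chunk."""
    pt = 0
    next_split = 0
    for n, pruned_len in enumerate(pruned_lens):
        if n == 0:
            pruned_len -= split_pruned_len
        else:
            split_pruned_len = 0
        vals = []
        for j in range(pruned_len):
            if pt + j >= len(rows):               # sequence continues in the next chunk
                next_split = split_pruned_len + j
                break
            vals.append(rows[pt + j])
        if split_pruned_len <= 0 and vals:
            grouped.append(vals)                  # start a new sequence
        elif vals:
            grouped[-1].extend(vals)              # continue the split sequence
        pt += pruned_len
    return next_split

grouped = []
# Pruned lengths [2, 3]; the 5 rows arrive as a chunk of 3 and a chunk of 2.
carry = group_rows(["a", "b", "c"], [2, 3], grouped)   # carry == 1: one row still pending
carry = group_rows(["d", "e"], [3], grouped, carry)    # second chunk re-slices sequence 1
assert grouped == [["a", "b"], ["c", "d", "e"]] and carry == 0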
@staticmethod
def get_token_ids_logprobs(
-        all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata
+        logprobs: torch.Tensor,
+        logits_metadata: LogitsMetadata,
+        chunk_slice: slice,
+        input_token_ids_logprobs_val: List,
+        input_token_ids_logprobs_idx: List,
+        split_pruned_len: int = 0,
):
-        input_token_ids_logprobs_val, input_token_ids_logprobs_idx = [], []
+        """Get token_ids logprobs for each sequence in the chunk.
+
+        Args:
+            logprobs: Log probabilities tensor of shape [seq_len, vocab_size]
+            logits_metadata: Metadata containing token IDs and pruned length info
+            chunk_slice: Slice of sequences to process
+            input_token_ids_logprobs_val: List to store token logprob values
+            input_token_ids_logprobs_idx: List to store token indices
+            split_pruned_len: Length of pruned tokens from previous chunk
+
+        Returns:
+            int: Number of remaining tokens to process in next chunk
+        """
pt = 0
-        for token_ids, pruned_len in zip(
-            logits_metadata.token_ids_logprobs,
-            logits_metadata.extend_logprob_pruned_lens_cpu,
+        next_split_pruned_len = 0
+        token_ids_logprobs_chunk = logits_metadata.token_ids_logprobs[chunk_slice]
+        pruned_lens = logits_metadata.extend_logprob_pruned_lens_cpu[chunk_slice]
+
+        for n, (token_ids, pruned_len) in enumerate(
+            zip(
+                token_ids_logprobs_chunk,
+                pruned_lens,
+            )
):
+            # Adjust pruned length for first sequence
+            if n == 0:
+                pruned_len -= split_pruned_len
+            else:
+                split_pruned_len = 0
+
if pruned_len <= 0:
-                input_token_ids_logprobs_val.append([])
-                input_token_ids_logprobs_idx.append([])
+                if n == 0:
+                    input_token_ids_logprobs_val.append([])
+                    input_token_ids_logprobs_idx.append([])
continue
-            input_token_ids_logprobs_val.append(
-                [all_logprobs[pt + j, token_ids].tolist() for j in range(pruned_len)]
-            )
-            input_token_ids_logprobs_idx.append([token_ids for _ in range(pruned_len)])
-            pt += pruned_len
+            val = []
+            idx = []
+            for j in range(pruned_len):
+                # Handle remaining tokens in next chunk if any
+                if pt + j >= logprobs.shape[0]:
+                    next_split_pruned_len = split_pruned_len + j
+                    break
+                if token_ids is not None:
+                    val.append(logprobs[pt + j, token_ids].tolist())
+                    idx.append(token_ids)
+
+            if split_pruned_len <= 0 and len(val) > 0:
+                input_token_ids_logprobs_val.append(val)
+                input_token_ids_logprobs_idx.append(idx)
+            elif len(val) > 0:
+                input_token_ids_logprobs_val[-1].extend(val)
+                input_token_ids_logprobs_idx[-1].extend(idx)
-        return input_token_ids_logprobs_val, input_token_ids_logprobs_idx
+            pt += pruned_len
+        return next_split_pruned_len
@staticmethod
def compute_temp_top_p_normalized_logprobs(
-        last_logits: torch.Tensor, logits_metadata: LogitsMetadata
+        last_logits: torch.Tensor,
+        indices: torch.Tensor,
+        logits_metadata: LogitsMetadata,
) -> torch.Tensor:
"""
compute logprobs for the output token from the given logits.
@@ -561,19 +763,20 @@ class LogitsProcessor(nn.Module):
"""
# Scale logits if temperature scaling is enabled
if logits_metadata.temp_scaled_logprobs:
-            last_logits = last_logits / logits_metadata.temperature
+            last_logits = last_logits / logits_metadata.temperature[indices]
+
+        top_p = None
+        if logits_metadata.top_p is not None:
+            top_p = logits_metadata.top_p[indices]
# Normalize logprobs if top_p normalization is enabled
# NOTE: only normalize logprobs when top_p is set and not equal to 1.0
-        if (
-            logits_metadata.top_p_normalized_logprobs
-            and (logits_metadata.top_p != 1.0).any()
-        ):
+        if logits_metadata.top_p_normalized_logprobs and (top_p != 1.0).any():
from sglang.srt.layers.sampler import top_p_normalize_probs_torch
probs = torch.softmax(last_logits, dim=-1)
del last_logits
-            probs = top_p_normalize_probs_torch(probs, logits_metadata.top_p)
+            probs = top_p_normalize_probs_torch(probs, top_p)
return torch.log(probs)
else:
return torch.nn.functional.log_softmax(last_logits, dim=-1)
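For clarity, a minimal sketch of what "temperature-scaled, top-p-normalized logprobs" means for a single row, written with plain PyTorch rather than sglang's top_p_normalize_probs_torch (whose implementation is not shown in this diff); the nucleus filter below is a common formulation and is not necessarily identical to the library's.

import torch

def temp_top_p_logprobs(logits: torch.Tensor, temperature: float, top_p: float) -> torch.Tensor:
    # Temperature scaling, then probabilities.
    probs = torch.softmax(logits / temperature, dim=-1)
    # Keep the smallest prefix of tokens (by probability) whose mass reaches top_p.
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    cumulative = sorted_probs.cumsum(dim=-1)
    drop = (cumulative - sorted_probs) >= top_p          # mass already reached before this token
    sorted_probs[drop] = 0.0
    probs = torch.zeros_like(probs).scatter_(-1, sorted_idx, sorted_probs)
    # Renormalize inside the nucleus and return logprobs (-inf outside it).
    probs = probs / probs.sum(dim=-1, keepdim=True)
    return torch.log(probs)

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
print(temp_top_p_logprobs(logits, temperature=0.7, top_p=0.9))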
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 5390668c..db370d19 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -687,10 +687,21 @@ class FlushCacheReqOutput:
success: bool
+@dataclass
+class InterruptAllReqInput:
+    pass
+
+
+@dataclass
+class InterruptAllReqOutput:
+    num_interrupted_requests: int
+
+
@dataclass
class UpdateWeightFromDiskReqInput:
# The model path with the new weights
model_path: str
+    allow_interrupt: bool = False
# The format to load the weights
load_format: Optional[str] = None
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 1178eec5..318dee33 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -73,6 +73,8 @@ from sglang.srt.managers.io_struct import (
HealthCheckOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
+    InterruptAllReqInput,
+    InterruptAllReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
@@ -427,6 +429,7 @@ class Scheduler(
# Init request dispatcher
self._request_dispatcher = TypeBasedDispatcher(
[
+                (InterruptAllReqInput, self.interrupt_all_requests),
(TokenizedGenerateReqInput, self.handle_generate_request),
(TokenizedEmbeddingReqInput, self.handle_embedding_request),
(FlushCacheReqInput, self.flush_cache_wrapped),
@@ -1971,6 +1974,15 @@ class Scheduler(
def _pause_engine(self) -> Tuple[List[Req], int]:
raise NotImplementedError()
+    def interrupt_all_requests(self, recv_req: InterruptAllReqInput):
+        num = len(self.waiting_queue) + len(self.running_batch.reqs)
+        for req in self.waiting_queue:
+            req.sampling_params.max_new_tokens = 0
+        for req in self.running_batch.reqs:
+            req.sampling_params.max_new_tokens = len(req.output_ids)
+        logger.info(f"Interrupt {num} requests.")
+        return InterruptAllReqOutput(num)
+
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
"""In-place update of the weights from disk."""
success, message = self.tp_worker.update_weights_from_disk(recv_req)
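Why rewriting max_new_tokens interrupts requests: queued requests get a budget of 0 and finish before generating anything, while running requests get a budget equal to what they have already produced and finish at their next length check. A hypothetical simplification of that finish condition (the scheduler's real check lives elsewhere and is not shown in this diff):

from dataclasses import dataclass, field

@dataclass
class FakeReq:
    max_new_tokens: int
    output_ids: list = field(default_factory=list)

    def is_finished(self) -> bool:
        # A length-based finish condition, as a stand-in for the real one.
        return len(self.output_ids) >= self.max_new_tokens

waiting = FakeReq(max_new_tokens=0)                        # was queued: never starts generating
running = FakeReq(max_new_tokens=3, output_ids=[5, 7, 9])  # was mid-generation: stops right away
assert waiting.is_finished() and running.is_finished()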
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index b646fae1..c668728b 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -80,6 +80,8 @@ from sglang.srt.managers.io_struct import (
HealthCheckOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
+    InterruptAllReqInput,
+    InterruptAllReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
@@ -279,6 +281,9 @@ class TokenizerManager:
self.slow_down_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
+        self.interrupt_requests_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
self.flush_cache_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
@@ -309,6 +314,10 @@ class TokenizerManager:
UpdateWeightFromDiskReqOutput,
self._handle_update_weights_from_disk_req_output,
),
+            (
+                InterruptAllReqOutput,
+                self.interrupt_requests_communicator.handle_recv,
+            ),
(
InitWeightsUpdateGroupReqOutput,
self.init_weights_update_group_communicator.handle_recv,
@@ -799,6 +808,13 @@ class TokenizerManager:
) -> Tuple[bool, str]:
self.auto_create_handle_loop()
+        if obj.allow_interrupt:
+            num_interrupted_requests = await self.interrupt_all_requests(
+                InterruptAllReqInput()
+            )
+            # Wait briefly for the interrupt to take effect
+            await asyncio.sleep(0.1)
+
# default the load format to the server_args
if obj.load_format is None:
obj.load_format = self.server_args.load_format
@@ -808,7 +824,12 @@ class TokenizerManager:
# Hold the lock if it is not async. This means that weight sync
# cannot run while requests are in progress.
async with self.model_update_lock.writer_lock:
-            return await self._wait_for_model_update_from_disk(obj)
+            success, message, n_paused = (
+                await self._wait_for_model_update_from_disk(obj)
+            )
+            if obj.allow_interrupt:
+                return success, message, num_interrupted_requests
+            return success, message, n_paused
async def _wait_for_model_update_from_disk(
self, obj: UpdateWeightFromDiskReqInput
@@ -881,6 +902,18 @@ class TokenizerManager:
result = (await self.update_weights_from_tensor_communicator(obj))[0]
return result.success, result.message
+    async def interrupt_all_requests(
+        self,
+        obj: InterruptAllReqInput,
+        request: Optional[fastapi.Request] = None,
+    ) -> Tuple[bool, str]:
+        self.auto_create_handle_loop()
+        result = await self.interrupt_requests_communicator(obj)
+        if self.server_args.dp_size == 1:
+            return result[0].num_interrupted_requests
+        else:
+            return [r.num_interrupted_requests for r in result]
+
async def get_weights_by_name(
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
):
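Putting the pieces together, a hedged usage sketch of an interruptible weight sync from a client: this assumes the server's existing /update_weights_from_disk HTTP endpoint forwards the request body into UpdateWeightFromDiskReqInput (the endpoint wiring is not part of this diff), and the URL and checkpoint path are placeholders.

import requests

resp = requests.post(
    "http://localhost:30000/update_weights_from_disk",
    json={
        "model_path": "/path/to/new/checkpoint",  # placeholder path
        "allow_interrupt": True,                  # interrupt queued and running requests first
    },
)
print(resp.json())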