AReaL/patch/sglang/v0.4.6.post4.patch


diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py
index 60091b9a..7a1c856b 100644
--- a/python/sglang/srt/layers/logits_processor.py
+++ b/python/sglang/srt/layers/logits_processor.py
@@ -15,6 +15,7 @@
import dataclasses
import logging
+import os
from typing import List, Optional, Union
import torch
@@ -44,6 +45,15 @@ from sglang.srt.model_executor.forward_batch_info import (
)
from sglang.srt.utils import dump_to_file
+# When computing the input and output token logprobs, materializing the full
+# logprobs tensor can make peak memory usage too high. For example, a
+# [10000, 150000] float32 logprobs tensor needs more than
+# 10000 * 150000 * 4 / 1024 / 1024 = 5722.05 MB (4 bytes per float).
+# So we split the computation into chunks of rows.
+LOGITS_PROCESSER_CHUNK_SIZE = int(
+ os.environ.get("SGLANG_LOGITS_PROCESSER_CHUNK_SIZE", "2048")
+)
+
logger = logging.getLogger(__name__)
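
For reference, a minimal standalone sketch of the arithmetic behind this setting, using the example figures from the comment above; the environment variable name is the one the patch reads, everything else is illustrative:

import os

# Example from the comment above: a [10000, 150000] float32 logprobs tensor.
rows, vocab_size, bytes_per_float = 10_000, 150_000, 4
peak_mib = rows * vocab_size * bytes_per_float / 1024 / 1024
print(f"full tensor: {peak_mib:.2f} MiB")           # ~5722.05 MiB

# Same ceil-division the patched code uses to decide how many chunks to run.
chunk_size = int(os.environ.get("SGLANG_LOGITS_PROCESSER_CHUNK_SIZE", "2048"))
num_chunks = (rows + chunk_size - 1) // chunk_size  # 5 chunks for 10000 rows
per_chunk_mib = chunk_size * vocab_size * bytes_per_float / 1024 / 1024
print(f"{num_chunks} chunks, ~{per_chunk_mib:.2f} MiB of logprobs each")
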
@@ -286,15 +296,17 @@ class LogitsProcessor(nn.Module):
input_logprob_indices = None
else:
# Input logprobs are required.
- # Find 3 different indices.
+ # Find 4 different indices.
# 1. pruned_states: hidden states that we want logprobs from.
# 2. sample_indices: Indices that have sampled tokens.
# 3. input_logprob_indices: Indices that have input logprob tokens.
+ # 4. sequence_index_mapping: maps pruned_states row indices to the corresponding indices in top_logprobs_nums and token_ids_logprobs.
sample_index_pt = -1
sample_indices = []
input_logprob_indices_pt = 0
input_logprob_indices = []
pt, pruned_states = 0, []
+ idx, sequence_index_mapping = 0, []
for extend_logprob_start_len, extend_len in zip(
logits_metadata.extend_logprob_start_lens_cpu,
logits_metadata.extend_seq_lens_cpu,
@@ -310,7 +322,11 @@ class LogitsProcessor(nn.Module):
# by a caller.
assert extend_len > start_len
pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
+ # sequence_index_mapping: repeat this sequence's loop index once per pruned token
+ sequence_index_mapping.extend([idx] * (extend_len - start_len))
+ idx += 1
pt += extend_len
+
sample_index_pt += extend_len - start_len
sample_indices.append(sample_index_pt)
input_logprob_indices.extend(
@@ -321,6 +337,7 @@ class LogitsProcessor(nn.Module):
)
input_logprob_indices_pt += extend_len - start_len
+ sequence_index_mapping.append(idx - 1)
pruned_states = torch.cat(pruned_states)
sample_indices = torch.tensor(
sample_indices, device=pruned_states.device, dtype=torch.int64
@@ -329,12 +346,6 @@ class LogitsProcessor(nn.Module):
input_logprob_indices, device=pruned_states.device, dtype=torch.int64
)
- # Compute logits for both input and sampled tokens.
- logits = self._get_logits(pruned_states, lm_head, logits_metadata)
- sampled_logits = (
- logits[sample_indices] if sample_indices is not None else logits
- )
-
if self.debug_tensor_dump_output_folder:
assert (
not self.do_tensor_parallel_all_gather
@@ -370,67 +381,176 @@ class LogitsProcessor(nn.Module):
else:
assert False, "Should never reach"
+ del hidden_states
+
if not logits_metadata.extend_return_logprob:
+ # Compute logits for both input and sampled tokens.
+ logits = self._get_logits(pruned_states, lm_head, logits_metadata)
+ sampled_logits = (
+ logits[sample_indices] if sample_indices is not None else logits
+ )
+
# Decode mode or extend mode without return_logprob.
return LogitsProcessorOutput(
next_token_logits=sampled_logits,
hidden_states=hidden_states_to_store,
)
else:
- input_logprobs = logits[input_logprob_indices]
- del hidden_states, logits
+ # Computing logprobs requires a lot of memory, so we split pruned_states
+ # into chunks of rows, compute input_logprobs per chunk, and concatenate
+ # the results.
+ return self._compute_output_by_chunk(
+ pruned_states,
+ sample_indices,
+ hidden_states_to_store,
+ input_logprob_indices,
+ sequence_index_mapping,
+ lm_head,
+ logits_metadata,
+ )
- # Normalize the logprob w/o temperature, top-p
- pruned_lens = torch.tensor(
- logits_metadata.extend_logprob_pruned_lens_cpu,
- device=input_logprobs.device,
+ def _compute_output_by_chunk(
+ self,
+ pruned_states: torch.Tensor,
+ sample_indices: torch.Tensor,
+ hidden_states_to_store: Optional[torch.Tensor],
+ input_logprob_indices: torch.Tensor,
+ index_mapping: list[int],
+ lm_head: VocabParallelEmbedding,
+ logits_metadata: LogitsMetadata,
+ ) -> LogitsProcessorOutput:
+ """
+ Compute sampled-token logits and input-token logprobs from the hidden states.
+ To avoid using too much memory, we split pruned_states into chunks of
+ rows to compute input_logprobs separately, then concatenate the results.
+
+ Returns:
+ LogitsProcessorOutput: the combined outputs for all chunks
+ """
+
+ # Normalize the logprob w/o temperature, top-p
+ pruned_lens = torch.tensor(
+ logits_metadata.extend_logprob_pruned_lens_cpu,
+ device=pruned_states.device,
+ )
+ if logits_metadata.temp_scaled_logprobs:
+ logits_metadata.temperature = torch.repeat_interleave(
+ logits_metadata.temperature.view(-1),
+ pruned_lens,
+ ).view(-1, 1)
+ if logits_metadata.top_p_normalized_logprobs:
+ logits_metadata.top_p = torch.repeat_interleave(
+ logits_metadata.top_p,
+ pruned_lens,
)
- if logits_metadata.temp_scaled_logprobs:
- logits_metadata.temperature = torch.repeat_interleave(
- logits_metadata.temperature.view(-1),
- pruned_lens,
- ).view(-1, 1)
- if logits_metadata.top_p_normalized_logprobs:
- logits_metadata.top_p = torch.repeat_interleave(
- logits_metadata.top_p,
- pruned_lens,
- )
- input_logprobs = self.compute_temp_top_p_normalized_logprobs(
- input_logprobs, logits_metadata
+
+ # The peak memory usage is proportional to the chunk size.
+ chunk_size = LOGITS_PROCESSER_CHUNK_SIZE
+ num_chunks = (pruned_states.shape[0] + chunk_size - 1) // chunk_size
+
+ input_token_logprobs = []
+ if logits_metadata.extend_return_top_logprob:
+ input_top_logprobs_val = []
+ input_top_logprobs_idx = []
+ else:
+ input_top_logprobs_val = None
+ input_top_logprobs_idx = None
+
+ if logits_metadata.extend_token_ids_logprob:
+ input_token_ids_logprobs_val = []
+ input_token_ids_logprobs_idx = []
+ else:
+ input_token_ids_logprobs_val = None
+ input_token_ids_logprobs_idx = None
+
+ # If a single sequence is split across multiple chunks, we need to keep
+ # track of how many of its pruned tokens were handled in previous chunks.
+ split_len_topk = 0
+ split_len_token_ids = 0
+
+ for i in range(num_chunks):
+ start_idx = i * chunk_size
+ end_idx = min((i + 1) * chunk_size, pruned_states.shape[0])
+
+ # Get indices for this chunk
+ chunk_mask = (input_logprob_indices >= start_idx) & (
+ input_logprob_indices < end_idx
+ )
+
+ global_indices = input_logprob_indices[chunk_mask]
+ chunk_indices = global_indices - start_idx
+ chunk_states = pruned_states[start_idx:end_idx]
+ chunk_logits = self._get_logits(chunk_states, lm_head, logits_metadata)
+
+ if chunk_indices.numel() == 0:
+ continue
+
+ # Compute the logprobs of the chunk
+ chunk_input_logprobs = chunk_logits[chunk_indices]
+ chunk_input_logprobs = self.compute_temp_top_p_normalized_logprobs(
+ chunk_input_logprobs, global_indices, logits_metadata
)
+ # Map this chunk's row range back to the slice of sequences it covers
+ chunk_slice = slice(index_mapping[start_idx], index_mapping[end_idx] + 1)
+
# Get the logprob of top-k tokens
if logits_metadata.extend_return_top_logprob:
- (
+ split_len_topk = self.get_top_logprobs(
+ chunk_input_logprobs,
+ logits_metadata,
+ chunk_slice,
input_top_logprobs_val,
input_top_logprobs_idx,
- ) = self.get_top_logprobs(input_logprobs, logits_metadata)
- else:
- input_top_logprobs_val = input_top_logprobs_idx = None
+ split_len_topk,
+ )
# Get the logprob of given token id
if logits_metadata.extend_token_ids_logprob:
- (
+ split_len_token_ids = self.get_token_ids_logprobs(
+ chunk_input_logprobs,
+ logits_metadata,
+ chunk_slice,
input_token_ids_logprobs_val,
input_token_ids_logprobs_idx,
- ) = self.get_token_ids_logprobs(input_logprobs, logits_metadata)
- else:
- input_token_ids_logprobs_val = input_token_ids_logprobs_idx = None
-
- input_token_logprobs = input_logprobs[
- torch.arange(input_logprobs.shape[0], device=input_logprobs.device),
- logits_metadata.extend_input_logprob_token_ids_gpu,
- ]
+ split_len_token_ids,
+ )
- return LogitsProcessorOutput(
- next_token_logits=sampled_logits,
- input_token_logprobs=input_token_logprobs,
- input_top_logprobs_val=input_top_logprobs_val,
- input_top_logprobs_idx=input_top_logprobs_idx,
- hidden_states=hidden_states_to_store,
- input_token_ids_logprobs_val=input_token_ids_logprobs_val,
- input_token_ids_logprobs_idx=input_token_ids_logprobs_idx,
+ # Handle sampled logits for the chunk if needed
+ chunk_sample_mask = (sample_indices >= start_idx) & (
+ sample_indices < end_idx
)
+ if i == 0: # Initialize sampled_logits on first chunk
+ sampled_logits = torch.empty(
+ (sample_indices.shape[0], chunk_logits.shape[1]),
+ dtype=chunk_logits.dtype,
+ device=chunk_logits.device,
+ )
+ if chunk_sample_mask.any():
+ chunk_sample_indices = sample_indices[chunk_sample_mask] - start_idx
+ sampled_logits[chunk_sample_mask] = chunk_logits[chunk_sample_indices]
+
+ # Get the logprob of the requested token ids
+ chunk_input_token_logprobs = chunk_input_logprobs[
+ torch.arange(
+ chunk_input_logprobs.shape[0], device=chunk_input_logprobs.device
+ ),
+ logits_metadata.extend_input_logprob_token_ids_gpu[start_idx:end_idx],
+ ]
+ input_token_logprobs.append(chunk_input_token_logprobs)
+
+ # Concatenate the results
+ input_token_logprobs = torch.cat(input_token_logprobs, dim=0)
+
+ return LogitsProcessorOutput(
+ hidden_states=hidden_states_to_store,
+ next_token_logits=sampled_logits,
+ input_token_logprobs=input_token_logprobs,
+ input_top_logprobs_val=input_top_logprobs_val,
+ input_top_logprobs_idx=input_top_logprobs_idx,
+ input_token_ids_logprobs_val=input_token_ids_logprobs_val,
+ input_token_ids_logprobs_idx=input_token_ids_logprobs_idx,
+ )
def _get_logits(
self,
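
To see how the new sequence_index_mapping is used, here is a small illustrative sketch with invented lengths; it reads the extra append after the loop as a sentinel so that index_mapping[end_idx] stays in bounds when a chunk ends exactly at the last row:

# Invented per-sequence pruned lengths (extend_len - start_len for each request).
pruned_lens = [3, 5, 2]

index_mapping = []
for seq_idx, n in enumerate(pruned_lens):
    index_mapping.extend([seq_idx] * n)      # one entry per pruned row
index_mapping.append(len(pruned_lens) - 1)   # sentinel for end_idx == total rows
# index_mapping == [0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2]

chunk_size = 4
start_idx, end_idx = 0, min(chunk_size, sum(pruned_lens))
chunk_slice = slice(index_mapping[start_idx], index_mapping[end_idx] + 1)
# chunk_slice == slice(0, 2): rows 0..3 touch sequences 0 and 1, so only their
# top_logprobs_nums / token_ids_logprobs entries are consulted for this chunk.
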
@@ -498,60 +618,142 @@ class LogitsProcessor(nn.Module):
return logits
@staticmethod
- def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
+ def get_top_logprobs(
+ logprobs: torch.Tensor,
+ logits_metadata: LogitsMetadata,
+ chunk_slice: slice,
+ input_top_logprobs_val: List,
+ input_top_logprobs_idx: List,
+ split_pruned_len: int,
+ ):
+ """Get top-k logprobs for each sequence in the chunk.
+
+ Args:
+ logprobs: Log probabilities tensor of shape [seq_len, vocab_size]
+ logits_metadata: Metadata containing top-k and pruned length info
+ chunk_slice: Slice of sequences to process
+ input_top_logprobs_val: List to store top-k logprob values
+ input_top_logprobs_idx: List to store top-k token indices
+ split_pruned_len: Number of the first sequence's tokens already handled in previous chunks
+
+ Returns:
+ int: Number of tokens of the last, partially processed sequence handled so far; passed as split_pruned_len for the next chunk
+ """
+
max_k = max(logits_metadata.top_logprobs_nums)
- ret = all_logprobs.topk(max_k, dim=1)
+ ret = logprobs.topk(max_k, dim=1)
values = ret.values.tolist()
indices = ret.indices.tolist()
- input_top_logprobs_val, input_top_logprobs_idx = [], []
-
pt = 0
- for k, pruned_len in zip(
- logits_metadata.top_logprobs_nums,
- logits_metadata.extend_logprob_pruned_lens_cpu,
- ):
+ next_split_pruned_len = 0
+ top_k_nums = logits_metadata.top_logprobs_nums[chunk_slice]
+ pruned_lens = logits_metadata.extend_logprob_pruned_lens_cpu[chunk_slice]
+
+ for n, (k, pruned_len) in enumerate(zip(top_k_nums, pruned_lens)):
+ # Adjust pruned length for first sequence
+ if n == 0:
+ pruned_len -= split_pruned_len
+ else:
+ split_pruned_len = 0
+
if pruned_len <= 0:
- input_top_logprobs_val.append([])
- input_top_logprobs_idx.append([])
+ if n == 0:
+ input_top_logprobs_val.append([])
+ input_top_logprobs_idx.append([])
continue
- input_top_logprobs_val.append(
- [values[pt + j][:k] for j in range(pruned_len)]
- )
- input_top_logprobs_idx.append(
- [indices[pt + j][:k] for j in range(pruned_len)]
- )
- pt += pruned_len
+ val = []
+ idx = []
+ for j in range(pruned_len):
+ # The rest of this sequence falls into the next chunk; record progress and stop
+ if pt + j >= len(values):
+ next_split_pruned_len = split_pruned_len + j
+ break
+ val.append(values[pt + j][:k])
+ idx.append(indices[pt + j][:k])
+
+ if split_pruned_len <= 0 and len(val) > 0:
+ input_top_logprobs_val.append(val)
+ input_top_logprobs_idx.append(idx)
+ else:
+ input_top_logprobs_val[-1].extend(val)
+ input_top_logprobs_idx[-1].extend(idx)
- return input_top_logprobs_val, input_top_logprobs_idx
+ pt += pruned_len
+ return next_split_pruned_len
@staticmethod
def get_token_ids_logprobs(
- all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata
+ logprobs: torch.Tensor,
+ logits_metadata: LogitsMetadata,
+ chunk_slice: slice,
+ input_token_ids_logprobs_val: List,
+ input_token_ids_logprobs_idx: List,
+ split_pruned_len: int = 0,
):
- input_token_ids_logprobs_val, input_token_ids_logprobs_idx = [], []
+ """Get token_ids logprobs for each sequence in the chunk.
+
+ Args:
+ logprobs: Log probabilities tensor of shape [seq_len, vocab_size]
+ logits_metadata: Metadata containing token IDs and pruned length info
+ chunk_slice: Slice of sequences to process
+ input_token_ids_logprobs_val: List to store token logprob values
+ input_token_ids_logprobs_idx: List to store token indices
+ split_pruned_len: Number of the first sequence's tokens already handled in previous chunks
+
+ Returns:
+ int: Number of tokens of the last, partially processed sequence handled so far; passed as split_pruned_len for the next chunk
+ """
pt = 0
- for token_ids, pruned_len in zip(
- logits_metadata.token_ids_logprobs,
- logits_metadata.extend_logprob_pruned_lens_cpu,
+ next_split_pruned_len = 0
+ token_ids_logprobs_chunk = logits_metadata.token_ids_logprobs[chunk_slice]
+ pruned_lens = logits_metadata.extend_logprob_pruned_lens_cpu[chunk_slice]
+
+ for n, (token_ids, pruned_len) in enumerate(
+ zip(
+ token_ids_logprobs_chunk,
+ pruned_lens,
+ )
):
+ # Adjust pruned length for first sequence
+ if n == 0:
+ pruned_len -= split_pruned_len
+ else:
+ split_pruned_len = 0
+
if pruned_len <= 0:
- input_token_ids_logprobs_val.append([])
- input_token_ids_logprobs_idx.append([])
+ if n == 0:
+ input_token_ids_logprobs_val.append([])
+ input_token_ids_logprobs_idx.append([])
continue
- input_token_ids_logprobs_val.append(
- [all_logprobs[pt + j, token_ids].tolist() for j in range(pruned_len)]
- )
- input_token_ids_logprobs_idx.append([token_ids for _ in range(pruned_len)])
- pt += pruned_len
+ val = []
+ idx = []
+ for j in range(pruned_len):
+ # The rest of this sequence falls into the next chunk; record progress and stop
+ if pt + j >= logprobs.shape[0]:
+ next_split_pruned_len = split_pruned_len + j
+ break
+ if token_ids is not None:
+ val.append(logprobs[pt + j, token_ids].tolist())
+ idx.append(token_ids)
+
+ if split_pruned_len <= 0 and len(val) > 0:
+ input_token_ids_logprobs_val.append(val)
+ input_token_ids_logprobs_idx.append(idx)
+ elif len(val) > 0:
+ input_token_ids_logprobs_val[-1].extend(val)
+ input_token_ids_logprobs_idx[-1].extend(idx)
- return input_token_ids_logprobs_val, input_token_ids_logprobs_idx
+ pt += pruned_len
+ return next_split_pruned_len
@staticmethod
def compute_temp_top_p_normalized_logprobs(
- last_logits: torch.Tensor, logits_metadata: LogitsMetadata
+ last_logits: torch.Tensor,
+ indices: torch.Tensor,
+ logits_metadata: LogitsMetadata,
) -> torch.Tensor:
"""
compute logprobs for the output token from the given logits.
@@ -561,19 +763,20 @@ class LogitsProcessor(nn.Module):
"""
# Scale logits if temperature scaling is enabled
if logits_metadata.temp_scaled_logprobs:
- last_logits = last_logits / logits_metadata.temperature
+ last_logits = last_logits / logits_metadata.temperature[indices]
+
+ top_p = None
+ if logits_metadata.top_p is not None:
+ top_p = logits_metadata.top_p[indices]
# Normalize logprobs if top_p normalization is enabled
# NOTE: only normalize logprobs when top_p is set and not equal to 1.0
- if (
- logits_metadata.top_p_normalized_logprobs
- and (logits_metadata.top_p != 1.0).any()
- ):
+ if logits_metadata.top_p_normalized_logprobs and (top_p != 1.0).any():
from sglang.srt.layers.sampler import top_p_normalize_probs_torch
probs = torch.softmax(last_logits, dim=-1)
del last_logits
- probs = top_p_normalize_probs_torch(probs, logits_metadata.top_p)
+ probs = top_p_normalize_probs_torch(probs, top_p)
return torch.log(probs)
else:
return torch.nn.functional.log_softmax(last_logits, dim=-1)
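
The indices argument added to compute_temp_top_p_normalized_logprobs exists because temperature and top_p are repeat-interleaved to one entry per pruned token across the whole batch, while each call now only sees one chunk of rows. A minimal sketch of that indexing with invented shapes, using only standard PyTorch:

import torch

# One temperature per request, expanded to one entry per pruned token.
pruned_lens = torch.tensor([3, 5])
temperature = torch.tensor([0.7, 1.0])
temp_per_token = torch.repeat_interleave(temperature, pruned_lens).view(-1, 1)  # [8, 1]

# A chunk covering global rows 2..5 must pick out its own temperature rows
# before scaling, which is what passing the global indices achieves.
chunk_logits = torch.randn(4, 32)        # [chunk_rows, vocab]
global_indices = torch.arange(2, 6)      # positions of these rows in the full batch
scaled = chunk_logits / temp_per_token[global_indices]
chunk_logprobs = torch.log_softmax(scaled, dim=-1)
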
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 5390668c..db370d19 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -687,10 +687,21 @@ class FlushCacheReqOutput:
success: bool
+@dataclass
+class InterruptAllReqInput:
+ pass
+
+
+@dataclass
+class InterruptAllReqOutput:
+ num_interrupted_requests: int
+
+
@dataclass
class UpdateWeightFromDiskReqInput:
# The model path with the new weights
model_path: str
+ allow_interrupt: bool = False
# The format to load the weights
load_format: Optional[str] = None
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 1178eec5..318dee33 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -73,6 +73,8 @@ from sglang.srt.managers.io_struct import (
HealthCheckOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
+ InterruptAllReqInput,
+ InterruptAllReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
@@ -427,6 +429,7 @@ class Scheduler(
# Init request dispatcher
self._request_dispatcher = TypeBasedDispatcher(
[
+ (InterruptAllReqInput, self.interrupt_all_requests),
(TokenizedGenerateReqInput, self.handle_generate_request),
(TokenizedEmbeddingReqInput, self.handle_embedding_request),
(FlushCacheReqInput, self.flush_cache_wrapped),
@@ -1971,6 +1974,15 @@ class Scheduler(
def _pause_engine(self) -> Tuple[List[Req], int]:
raise NotImplementedError()
+ def interrupt_all_requests(self, recv_req: InterruptAllReqInput):
+ num = len(self.waiting_queue) + len(self.running_batch.reqs)
+ for req in self.waiting_queue:
+ req.sampling_params.max_new_tokens = 0
+ for req in self.running_batch.reqs:
+ req.sampling_params.max_new_tokens = len(req.output_ids)
+ logger.info(f"Interrupt {num} requests.")
+ return InterruptAllReqOutput(num)
+
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
"""In-place update of the weights from disk."""
success, message = self.tp_worker.update_weights_from_disk(recv_req)
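
The interrupt itself is a budget trick: waiting requests get a zero generation budget and running requests get a budget equal to what they have already produced, so both finish at the next scheduling step instead of being dropped. A self-contained sketch with stand-in classes (not sglang's real Req or SamplingParams):

from dataclasses import dataclass, field
from typing import List

@dataclass
class SamplingParams:
    max_new_tokens: int = 128

@dataclass
class Req:
    output_ids: List[int] = field(default_factory=list)
    sampling_params: SamplingParams = field(default_factory=SamplingParams)

waiting_queue = [Req()]                        # not started yet
running_reqs = [Req(output_ids=[11, 12, 13])]  # already generated 3 tokens

for req in waiting_queue:
    req.sampling_params.max_new_tokens = 0     # finishes as soon as it is scheduled
for req in running_reqs:
    # Cap the budget at what is already generated, so the request hits its
    # length limit and is returned to the client at the next step.
    req.sampling_params.max_new_tokens = len(req.output_ids)

print(len(waiting_queue) + len(running_reqs), "requests interrupted")
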
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index b646fae1..c668728b 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -80,6 +80,8 @@ from sglang.srt.managers.io_struct import (
HealthCheckOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
+ InterruptAllReqInput,
+ InterruptAllReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
@@ -279,6 +281,9 @@ class TokenizerManager:
self.slow_down_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
+ self.interrupt_requests_communicator = _Communicator(
+ self.send_to_scheduler, server_args.dp_size
+ )
self.flush_cache_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
@@ -309,6 +314,10 @@ class TokenizerManager:
UpdateWeightFromDiskReqOutput,
self._handle_update_weights_from_disk_req_output,
),
+ (
+ InterruptAllReqOutput,
+ self.interrupt_requests_communicator.handle_recv,
+ ),
(
InitWeightsUpdateGroupReqOutput,
self.init_weights_update_group_communicator.handle_recv,
@@ -799,6 +808,13 @@ class TokenizerManager:
) -> Tuple[bool, str]:
self.auto_create_handle_loop()
+ if obj.allow_interrupt:
+ num_interrupted_requests = await self.interrupt_all_requests(
+ InterruptAllReqInput()
+ )
+ # Briefly yield so the scheduler can process the interrupt before the update
+ await asyncio.sleep(0.1)
+
# default the load format to the server_args
if obj.load_format is None:
obj.load_format = self.server_args.load_format
@@ -808,7 +824,12 @@ class TokenizerManager:
# Hold the lock if it is not async. This means that weight sync
# cannot run while requests are in progress.
async with self.model_update_lock.writer_lock:
- return await self._wait_for_model_update_from_disk(obj)
+ success, message, n_paused = (
+ await self._wait_for_model_update_from_disk(obj)
+ )
+ if obj.allow_interrupt:
+ return success, message, num_interrupted_requests
+ return success, message, n_paused
async def _wait_for_model_update_from_disk(
self, obj: UpdateWeightFromDiskReqInput
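
From the caller's side, allow_interrupt changes both the side effects and the shape of the return value of TokenizerManager.update_weights_from_disk. A hedged usage sketch; the manager instance and model path are placeholders, and the three-element return mirrors the patched code above:

from sglang.srt.managers.io_struct import UpdateWeightFromDiskReqInput

async def reload_weights(tokenizer_manager, model_path: str):
    req = UpdateWeightFromDiskReqInput(model_path=model_path, allow_interrupt=True)
    # With allow_interrupt=True the manager first interrupts all in-flight
    # requests, sleeps briefly so the scheduler can act on the interrupt, and
    # only then performs the update under the writer lock.
    success, message, num_interrupted = await tokenizer_manager.update_weights_from_disk(req)
    return success, message, num_interrupted
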
@@ -881,6 +902,18 @@ class TokenizerManager:
result = (await self.update_weights_from_tensor_communicator(obj))[0]
return result.success, result.message
+ async def interrupt_all_requests(
+ self,
+ obj: InterruptAllReqInput,
+ request: Optional[fastapi.Request] = None,
+ ) -> Union[int, List[int]]:
+ self.auto_create_handle_loop()
+ result = await self.interrupt_requests_communicator(obj)
+ if self.server_args.dp_size == 1:
+ return result[0].num_interrupted_requests
+ else:
+ return [r.num_interrupted_requests for r in result]
+
async def get_weights_by_name(
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
):
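
Note that interrupt_all_requests returns a single count when dp_size == 1 and a per-rank list otherwise, so a caller that wants one total has to handle both shapes. An illustrative sketch, assuming summing per-rank counts is the desired aggregation:

from sglang.srt.managers.io_struct import InterruptAllReqInput

async def count_interrupted(tokenizer_manager) -> int:
    result = await tokenizer_manager.interrupt_all_requests(InterruptAllReqInput())
    if isinstance(result, list):   # dp_size > 1: one count per data-parallel rank
        return sum(result)
    return result                  # dp_size == 1: already a single int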