Skip to content

Commit

Permalink
support session control for vision language models
Browse files Browse the repository at this point in the history
  • Loading branch information
Ying1123 committed Nov 27, 2024
1 parent 0b46b95 commit e1f9938
Show file tree
Hide file tree
Showing 7 changed files with 265 additions and 21 deletions.
15 changes: 6 additions & 9 deletions python/sglang/srt/managers/image_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ async def process_images_async(
if not image_data:
return None

modalities = request_obj.modalities or ["image"]
aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
grid_pinpoints = (
self.hf_config.image_grid_pinpoints
Expand All @@ -139,9 +140,12 @@ async def process_images_async(
else None
)

if isinstance(image_data, str):
image_data = [image_data]

if isinstance(image_data, list) and len(image_data) > 0:
# Multiple images
if len(image_data) > 1:
if "multi-images" in modalities or "video" in modalities:
# Multiple images
aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
pixel_values, image_hashes, image_sizes = [], [], []
res = []
Expand All @@ -166,13 +170,6 @@ async def process_images_async(
)
image_hashes = [image_hash]
image_sizes = [image_size]
elif isinstance(image_data, str):
# A single image
pixel_values, image_hash, image_size = await self._process_single_image(
image_data, aspect_ratio, grid_pinpoints
)
image_hashes = [image_hash]
image_sizes = [image_size]
else:
raise ValueError(f"Invalid image data: {image_data}")

Expand Down
38 changes: 37 additions & 1 deletion python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import logging
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import triton
import triton.language as tl
Expand Down Expand Up @@ -167,6 +168,30 @@ def from_dict(obj, vocab_size):

return ret

def merge(self, other, vocab_size):
assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
self.image_hashes += other.image_hashes

self.pad_values = [
(self.image_hashes) % vocab_size,
(self.image_hashes >> 16) % vocab_size,
(self.image_hashes >> 32) % vocab_size,
(self.image_hashes >> 64) % vocab_size,
]

optional_args = [
"image_sizes",
"image_offsets",
# "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
"aspect_ratio_ids",
"aspect_ratio_mask",
"image_grid_thws",
]
for arg in optional_args:
if getattr(self, arg, None) is not None:
setattr(self, arg, getattr(self, arg) + getattr(other, arg))


class Req:
"""The input and output status of a request."""
Expand All @@ -177,14 +202,19 @@ def __init__(
origin_input_text: str,
origin_input_ids: Tuple[int],
sampling_params: SamplingParams,
origin_input_ids_unpadded: Optional[Tuple[int]] = None,
lora_path: Optional[str] = None,
input_embeds: Optional[List[List[float]]] = None,
session_id: Optional[str] = None,
):
# Input and output info
self.rid = rid
self.origin_input_text = origin_input_text
self.origin_input_ids_unpadded = origin_input_ids # Before image padding
self.origin_input_ids_unpadded = (
origin_input_ids_unpadded
if origin_input_ids_unpadded
else origin_input_ids # Before image padding
)
self.origin_input_ids = origin_input_ids
self.output_ids = [] # Each decode stage's output ids
self.fill_ids = None # fill_ids = origin_input_ids + output_ids
Expand Down Expand Up @@ -260,6 +290,12 @@ def __init__(
# The number of cached tokens, that were already cached in the KV cache
self.cached_tokens = 0

def extend_image_inputs(self, image_inputs, vocab_size):
if self.image_inputs is None:
self.image_inputs = image_inputs
else:
self.image_inputs.merge(image_inputs, vocab_size)

# whether request reached finished condition
def finished(self) -> bool:
return self.finished_reason is not None
Expand Down
5 changes: 3 additions & 2 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,12 +559,13 @@ def handle_generate_request(

# Image inputs
if recv_req.image_inputs is not None:
req.image_inputs = ImageInputs.from_dict(
image_inputs = ImageInputs.from_dict(
recv_req.image_inputs, self.model_config.vocab_size
)
req.origin_input_ids = self.pad_input_ids_func(
req.origin_input_ids_unpadded, req.image_inputs
req.origin_input_ids, image_inputs
)
req.extend_image_inputs(image_inputs, self.model_config.vocab_size)

if len(req.origin_input_ids) > self.max_req_input_len:
req.finished_reason = FINISH_ABORT(
Expand Down
19 changes: 15 additions & 4 deletions python/sglang/srt/managers/session_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,27 @@ def create_req(self, req: TokenizedGenerateReqInput, tokenizer):
]
+ req.input_ids
)
input_ids_unpadded = (
self.reqs[-1].origin_input_ids_unpadded
+ self.reqs[-1].output_ids[
: self.reqs[-1].sampling_params.max_new_tokens
]
+ req.input_ids
)
else:
input_ids = req.input_ids
input_ids_unpadded = req.input_ids
new_req = Req(
req.rid,
None,
input_ids,
req.sampling_params,
rid=req.rid,
origin_input_text=None,
origin_input_ids=input_ids,
origin_input_ids_unpadded=input_ids_unpadded,
sampling_params=req.sampling_params,
lora_path=req.lora_path,
session_id=self.session_id,
)
if len(self.reqs) > 0:
new_req.image_inputs = self.reqs[-1].image_inputs
new_req.tokenizer = tokenizer
if req.session_rid is not None and len(self.reqs) == 0:
new_req.finished_reason = FINISH_ABORT(
Expand Down
8 changes: 7 additions & 1 deletion python/sglang/srt/models/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,13 @@ def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
image_sizes, pad_values = image_inputs.image_sizes, image_inputs.pad_values

# hardcode for spatial_unpad + anyres
image_aspect_ratio = "anyres" if len(image_sizes) == 1 else "pad"
if image_inputs.modalities is not None and (
"multi-images" in image_inputs.modalities
or "video" in image_inputs.modalities
):
image_aspect_ratio = "pad"
else:
image_aspect_ratio = "anyres"
offset_list = []
for image_s in image_sizes:
if len(image_sizes) > 16:
Expand Down
1 change: 1 addition & 0 deletions test/srt/run_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
"test_triton_attention_backend.py",
"test_update_weights.py",
"test_vision_openai_server.py",
"test_session_control.py",
],
"sampling/penaltylib": glob.glob(
"sampling/penaltylib/**/test_*.py", recursive=True
Expand Down
Loading

0 comments on commit e1f9938

Please sign in to comment.