Skip to content

Commit

Permalink
[fix] Fix prefix caching for multi-image/video (#2239)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ying1123 authored Nov 28, 2024
1 parent 65fdb28 commit b7038fe
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 22 deletions.
36 changes: 21 additions & 15 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,15 +145,17 @@ def from_dict(obj, vocab_size):
# Use image hash as fake token_ids, which is then used for prefix matching
ret = ImageInputs(
pixel_values=obj["pixel_values"],
image_hashes=hash(tuple(obj["image_hashes"])),
image_hashes=obj["image_hashes"],
)
image_hash = ret.image_hashes
ret.pad_values = [
(image_hash) % vocab_size,
(image_hash >> 16) % vocab_size,
(image_hash >> 32) % vocab_size,
(image_hash >> 64) % vocab_size,
]
if not isinstance(ret.image_hashes, list):
ret.pad_values = [
(ret.image_hashes) % vocab_size,
(ret.image_hashes >> 16) % vocab_size,
(ret.image_hashes >> 32) % vocab_size,
(ret.image_hashes >> 64) % vocab_size,
]
else:
ret.pad_values = [x % vocab_size for x in ret.image_hashes]

optional_args = [
"image_sizes",
Expand All @@ -171,14 +173,18 @@ def from_dict(obj, vocab_size):
def merge(self, other, vocab_size):
assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
self.image_hashes += other.image_hashes

self.pad_values = [
(self.image_hashes) % vocab_size,
(self.image_hashes >> 16) % vocab_size,
(self.image_hashes >> 32) % vocab_size,
(self.image_hashes >> 64) % vocab_size,
]
if isinstance(self.image_hashes, list) and isinstance(other.image_hashes, list):
self.image_hashes += other.image_hashes
self.pad_values = [x % vocab_size for x in self.image_hashes]
else:
self.image_hashes = hash(tuple(self.image_hashes, other.image_hashes))
self.pad_values = [
(self.image_hashes) % vocab_size,
(self.image_hashes >> 16) % vocab_size,
(self.image_hashes >> 32) % vocab_size,
(self.image_hashes >> 64) % vocab_size,
]

optional_args = [
"image_sizes",
Expand Down
8 changes: 2 additions & 6 deletions python/sglang/srt/models/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
else:
image_aspect_ratio = "anyres"
offset_list = []
for image_s in image_sizes:
for image_idx, image_s in enumerate(image_sizes):
if len(image_sizes) > 16:
# 2x2 pooling with stride 2
new_image_feature_len = (
Expand Down Expand Up @@ -92,18 +92,14 @@ def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
new_w = int(new_w // times)
new_image_feature_len += new_h * (new_w + 1)

pad_ids = pad_values * (
(new_image_feature_len + len(pad_values)) // len(pad_values)
)
# print("calculated new_image_feature_len: ", new_image_feature_len)
try:
offset = input_ids.index(self.config.image_token_index)
except ValueError:
offset = 0
# old_len + pad_len - 1, because we need to remove image_token_id
input_ids = (
input_ids[:offset]
+ pad_ids[:new_image_feature_len]
+ [pad_values[image_idx]] * new_image_feature_len
+ input_ids[offset + 1 :]
)
offset_list.append(offset)
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/srt/models/qwen2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]):
return num_image_tokens

# Use grid_t * grid_w * grid_h to pad tokens for each image
# and replaced padding by unique image hash
# add replaced padding by unique image hash
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
image_grid_thws = image_inputs.image_grid_thws
pad_values = image_inputs.pad_values
Expand Down
2 changes: 2 additions & 0 deletions test/srt/test_session_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,8 @@ def test_session_control(self):
assert response["meta_info"]["finish_reason"]["type"] == "abort"

# 2. not use session control
requests.post(self.base_url + "/flush_cache")

input_ids_first_req = None
input_ids = []
outputs_normal = []
Expand Down

0 comments on commit b7038fe

Please sign in to comment.