Skip to content

Commit

Permalink
Merge pull request #264 from DocShotgun/robust-length-checking
Browse files: browse the repository at this point in the history
Robust request length checking in generator
  • Loading branch information
kingbri1 authored Dec 27, 2024
2 parents 7878d35 + b994aae commit 7094938
Showing 1 changed file with 38 additions and 6 deletions.
44 changes: 38 additions & 6 deletions backends/exllamav2/model.py
Original file line number Diff line number Diff line change
@@ -1307,17 +1307,49 @@ async def generate_gen(

# The first index will always be the positive prompt
context_len = input_ids[0].size(dim=-1)
if context_len > self.config.max_seq_len:
raise ValueError(
f"Context length {context_len} is greater than max_seq_len "
f"{self.config.max_seq_len}"
)

# The second index will be the negative prompt if CFG is enabled
negative_context_len = input_ids[1].size(dim=-1) if negative_prompt else 0

# Automatically set max_tokens to fill up the context
# This should be an OK default, but may be changed in the future
max_tokens = unwrap(
kwargs.get("max_tokens"), self.config.max_seq_len - context_len
kwargs.get("max_tokens"),
self.config.max_seq_len - max(context_len, negative_context_len),
)
if max_tokens < 1:
logger.warning("max_tokens must be a positive integer, setting to 1.")
max_tokens = 1

# Determine if the negative context or the context length is bigger
context_to_check = max(negative_context_len, context_len)

# Check highest possible total length of request
if context_to_check + max_tokens > self.config.max_seq_len:
preamble = (
"Negative prompt request"
if negative_context_len > context_len
else "Request"
)

raise ValueError(
f"{preamble} length {context_to_check} + {max_tokens} is greater than "
f"max_seq_len {self.config.max_seq_len}"
)

# Check total required pages for CFG request to avoid overallocation
if negative_prompt and (
sum(
256 * math.ceil((context + max_tokens) / 256)
for context in (context_len, negative_context_len)
)
> self.cache_size
):
raise ValueError(
f"Total required page size for request "
f"{context_len} + {negative_context_len} + {max_tokens} * 2 "
f"is greater than cache_size {self.cache_size}"
)

# Set min_tokens to generate while keeping EOS banned
min_tokens = unwrap(kwargs.get("min_tokens"), 0)
Expand Down

0 comments on commit 7094938

Please sign in to comment.