Commit

Logprobs fixes for streaming chat/completions.
This also brings the two chat/completions code paths back into
alignment.
zewt committed Jun 25, 2024
1 parent 7ebb0b2 commit 958e222
Showing 1 changed file with 26 additions and 18 deletions.
endpoints/OAI/utils/chat_completion.py
@@ -47,12 +47,12 @@ def _create_response(generations: List[dict], model_name: Optional[str]):
 
         logprob_response = None
 
-        tokens = unwrap(generation.get("tokens"), [])
         token_probs = unwrap(generation.get("token_probs"), [])
-        logprobs = unwrap(generation.get("logprobs"), [])
         if token_probs:
+            tokens = unwrap(generation.get("tokens"), [])
+            logprobs = unwrap(generation.get("logprobs"), [])
             collected_token_probs = []
-            for output_token, token_logprob, top_logprobs in zip(
+            for generated_token, generated_token_logprob, top_logprobs in zip(
                 tokens, token_probs, logprobs, strict=True
             ):
                 completion_logprobs = [
@@ -62,8 +62,8 @@ def _create_response(generations: List[dict], model_name: Optional[str]):
 
                 collected_token_probs.append(
                     ChatCompletionLogprobChoice(
-                        token=output_token,
-                        logprob=token_logprob,
+                        token=generated_token,
+                        logprob=generated_token_logprob,
                         top_logprobs=completion_logprobs,
                     )
                 )
@@ -112,22 +112,30 @@ def _create_stream_chunk(
         role="assistant", content=unwrap(generation.get("text"), "")
     )
 
+    logprob_response = None
+
     token_probs = unwrap(generation.get("token_probs"), {})
     if token_probs:
-        logprobs = unwrap(generation.get("logprobs"), {})
-        top_logprobs = [
-            ChatCompletionLogprob(token=token, logprob=logprob)
-            for token, logprob in logprobs.items()
-        ]
-
-        generated_token = next(iter(token_probs))
-        token_prob_response = ChatCompletionLogprob(
-            token=generated_token,
-            logprob=token_probs[generated_token],
-            top_logprobs=top_logprobs,
-        )
+        tokens = unwrap(generation.get("tokens"), [])
+        logprobs = unwrap(generation.get("logprobs"), [])
+        collected_token_probs = []
+        for generated_token, generated_token_logprob, top_logprobs in zip(
+            tokens, token_probs, logprobs, strict=True
+        ):
+            completion_logprobs = [
+                ChatCompletionLogprob(token=token, logprob=token_logprob)
+                for token, token_logprob in top_logprobs.items()
+            ]
+
+            collected_token_probs.append(
+                ChatCompletionLogprobChoice(
+                    token=generated_token,
+                    logprob=generated_token_logprob,
+                    top_logprobs=completion_logprobs,
+                )
+            )
 
-        logprob_response = ChatCompletionLogprobs(content=[token_prob_response])
+        logprob_response = ChatCompletionLogprobs(content=collected_token_probs)
 
     choice = ChatCompletionStreamChoice(
         index=index,
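For readers skimming the diff: before this commit, the streaming path treated `token_probs` as a dict and reported logprobs for only a single token per chunk (via `next(iter(token_probs))`), while the non-streaming path already looped over every generated token. After the commit, both paths build the response with the same per-token loop. The sketch below is a minimal, self-contained rendering of that shared logic; the dataclasses are simplified stand-ins for tabbyAPI's Pydantic response models, and `build_logprobs` is a hypothetical helper name, since the commit keeps the loop inline in each function rather than factoring it out.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional

# Simplified stand-ins for tabbyAPI's Pydantic response models.
@dataclass
class ChatCompletionLogprob:
    token: str
    logprob: float


@dataclass
class ChatCompletionLogprobChoice:
    token: str
    logprob: float
    top_logprobs: List[ChatCompletionLogprob] = field(default_factory=list)


@dataclass
class ChatCompletionLogprobs:
    content: List[ChatCompletionLogprobChoice] = field(default_factory=list)


def build_logprobs(generation: dict) -> Optional[ChatCompletionLogprobs]:
    """Hypothetical helper mirroring the loop this commit uses in both
    chat/completion paths: one logprob entry per generated token."""
    # `value or default` approximates tabbyAPI's unwrap(value, default).
    token_probs: List[float] = generation.get("token_probs") or []
    if not token_probs:
        return None

    tokens: List[str] = generation.get("tokens") or []
    logprobs: List[Dict[str, float]] = generation.get("logprobs") or []

    collected_token_probs = []
    # strict=True (Python 3.10+) raises if the three sequences ever
    # disagree in length instead of silently truncating the output.
    for generated_token, generated_token_logprob, top_logprobs in zip(
        tokens, token_probs, logprobs, strict=True
    ):
        completion_logprobs = [
            ChatCompletionLogprob(token=token, logprob=token_logprob)
            for token, token_logprob in top_logprobs.items()
        ]
        collected_token_probs.append(
            ChatCompletionLogprobChoice(
                token=generated_token,
                logprob=generated_token_logprob,
                top_logprobs=completion_logprobs,
            )
        )

    return ChatCompletionLogprobs(content=collected_token_probs)


# Example: a chunk covering two generated tokens now yields two entries,
# where the old streaming code would have reported only the first.
chunk = {
    "tokens": ["Hello", ","],
    "token_probs": [-0.1, -0.3],
    "logprobs": [{"Hello": -0.1, "Hi": -2.4}, {",": -0.3, "!": -1.9}],
}
print(build_logprobs(chunk))
```

Note that the `strict=True` zip guard was already present in the non-streaming path; the streaming path gains it with this change.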
