
Response prefill logprobs seem to become incorrect when using AsyncInferenceClient in some circumstances #2502

sadra-barikbin opened this issue Sep 6, 2024 · 2 comments


@sadra-barikbin
Contributor

Hi there!🤗

I came across a mismatch in prefill logprobs when sending requests to TGI asynchronously, as opposed to sending them synchronously.

Here is a reproduction of the issue:

import time
import random
import asyncio
from typing import Iterator

import pytest
import docker
import requests
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM
from huggingface_hub import (
    InferenceClient,
    AsyncInferenceClient,
    TextGenerationOutput,
)


@pytest.fixture(params=["sync", "async"])
def tgi_client(request) -> Iterator[InferenceClient|AsyncInferenceClient]:
    client = docker.from_env()

    try:
        container = client.containers.get("lighteval-tgi-model-test")
        port = container.ports["80/tcp"][0]["HostPort"]
    except docker.errors.NotFound:
        port = random.randint(8000, 9000)
        container = client.containers.run(
            "ghcr.io/huggingface/text-generation-inference:2.2.0",
            command=[
                "--model-id",
                "hf-internal-testing/tiny-random-LlamaForCausalLM",
                "--dtype",
                "float16",
            ],
            detach=True,
            name="lighteval-tgi-model-test",
            auto_remove=False,
            ports={"80/tcp": port},
        )
    address = f"http://localhost:{port}"
    for _ in range(40):
        try:
            if requests.get(f"{address}/health"):
                break
        except Exception:
            time.sleep(1)
    else:
        raise RuntimeError("Couldn't setup TGI server.")
    
    if request.param == "async":
        yield AsyncInferenceClient(base_url=address)    
    elif request.param == "sync":
        yield InferenceClient(base_url=address)
    else:
        raise RuntimeError()


def test_logprobs(tgi_client: InferenceClient|AsyncInferenceClient):
    model: LlamaForCausalLM = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
    
    # The test fails in the async setting unless `prompts` has fewer than 3 items
    prompts = [
        "Tell me:\n\nHow are you?Fine, thanks!",
        "Tell me:\n\nHow are you?Not bad!",
        "Tell me:\n\nComment vas-tu?Comme ci, comme ça",
        "Tell me:\n\nComment vas-tu?Ca va! Merci!",
    ]
    responses = []
    for prompt in prompts:
        responses.append(tgi_client.text_generation(
            prompt,
            details=True,
            decoder_input_details=True,
            max_new_tokens=1,
            stop_sequences=None,
            do_sample=False,
            return_full_text=False,
            seed=42,
        ))
    if isinstance(tgi_client, AsyncInferenceClient):
        loop = asyncio.get_event_loop()
        responses: list[TextGenerationOutput] = loop.run_until_complete(asyncio.gather(*responses))
    
    error = False
    for prompt, response in zip(prompts, responses):

        tgi_logprobs = torch.tensor([t.logprob for t in response.details.prefill[1:]]) # Skipping <s> whose logprob is None
        
        tokenized_sequence = tokenizer(prompt, return_tensors='pt')['input_ids']
        output = model.generate(tokenized_sequence, max_new_tokens=1, return_dict_in_generate=True, output_hidden_states=True)
        with torch.no_grad():
            logprobs = torch.log_softmax(model.lm_head(output.hidden_states[0][-1]),dim=-1)
        logprobs = logprobs.gather(dim=-1, index=tokenized_sequence[:,1:].unsqueeze(-1)).squeeze()
        
        if not torch.allclose(logprobs.sum(), tgi_logprobs.sum()):
            print(f"====== prompt: {repr(prompt)} ======")
            print("TGI logprobs:", tgi_logprobs.tolist())
            print("TGI tokens:",[t.id for t in response.details.prefill])
            print("Ref. logprobs:", logprobs.tolist())
            print("Ref. tokens:", tokenized_sequence[0].tolist())
            error = True
    assert not error
Output in the async case:

====== prompt: 'Tell me:\n\nHow are you?Not bad!' ======
TGI logprobs: [-10.390103340148926, -10.234614372253418, -10.583305358886719, -10.415820121765137, -10.48030948638916, -10.504646301269531, -10.360091209411621, -10.294694900512695, -10.363240242004395, -10.283424377441406, -10.3839693069458, -10.447381019592285]
TGI tokens: [1, 24948, 592, 29901, 13, 13, 5328, 526, 366, 29973, 3664, 4319, 29991]
Ref. logprobs: [-10.23465633392334, -10.583168983459473, -10.415797233581543, -10.480293273925781, -10.504618644714355, -10.360037803649902, -10.294654846191406, -10.363250732421875, -10.283416748046875, -10.38398265838623, -10.44735050201416, -10.320746421813965]
Ref. tokens: [1, 24948, 592, 29901, 13, 13, 5328, 526, 366, 29973, 3664, 4319, 29991]
====== prompt: 'Tell me:\n\nComment vas-tu?Ca va! Merci!' ======
TGI logprobs: [-10.385672569274902, -10.234614372253418, -10.583305358886719, -10.415820121765137, -10.48030948638916, -10.504646301269531, -10.199889183044434, -10.274925231933594, -10.404882431030273, -10.23915958404541, -10.417717933654785, -10.42142391204834, -10.367297172546387, -10.263328552246094, -10.426572799682617, -10.497278213500977]
TGI tokens: [1, 24948, 592, 29901, 13, 13, 20001, 19723, 29899, 9161, 29973, 26270, 2947, 29991, 4702, 455, 29991]
Ref. logprobs: [-10.23465633392334, -10.583168983459473, -10.415797233581543, -10.480293273925781, -10.504618644714355, -10.19986629486084, -10.274916648864746, -10.404905319213867, -10.239229202270508, -10.417610168457031, -10.421340942382812, -10.367280960083008, -10.263335227966309, -10.42658805847168, -10.49729061126709, -10.374680519104004]
Ref. tokens: [1, 24948, 592, 29901, 13, 13, 20001, 19723, 29899, 9161, 29973, 26270, 2947, 29991, 4702, 455, 29991]

It seems that in the failing cases, the TGI prefill logprobs have an extra item on the left and are missing one on the right, i.e. they are shifted by one position relative to the reference values.
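
As a quick sanity check (a minimal sketch, not part of the original script: the values are copied verbatim from the "Not bad!" output above and the tolerance is an arbitrary choice), dropping the first TGI value lines the remaining ones up with the reference logprobs:

import torch

# First few prefill logprobs from the failing "Not bad!" case above.
tgi_vals = torch.tensor([-10.390103340148926, -10.234614372253418, -10.583305358886719,
                         -10.415820121765137, -10.48030948638916, -10.504646301269531])
ref_vals = torch.tensor([-10.23465633392334, -10.583168983459473, -10.415797233581543,
                         -10.480293273925781, -10.504618644714355])

# Skipping the first TGI value aligns the rest with the reference values,
# consistent with a one-position shift of the prefill logprobs.
print(torch.allclose(tgi_vals[1:], ref_vals, atol=1e-3))  # True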

@ErikKaum
Member

ErikKaum commented Sep 9, 2024

Hi @sadra-barikbin 👋

Thanks for reporting this.
Btw, quick question: do you suspect this is in the TGI server, or is it something that goes wrong in the AsyncInferenceClient?

I'll ping @Wauplin just in case.

@sadra-barikbin
Contributor Author

@ErikKaum, it seems to be rooted in TGI. Here is a reproduction using httpx directly:

import time
import random
import asyncio
from typing import Iterator

import pytest
import docker
import httpx
import requests
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM


@pytest.fixture(params=["sync", "async"])
def tgi_client(request) -> Iterator[httpx.Client|httpx.AsyncClient]:
    client = docker.from_env()

    try:
        container = client.containers.get("lighteval-tgi-model-test")
        port = container.ports["80/tcp"][0]["HostPort"]
    except docker.errors.NotFound:
        port = random.randint(8000, 9000)
        container = client.containers.run(
            "ghcr.io/huggingface/text-generation-inference:2.2.0",
            command=[
                "--model-id",
                "hf-internal-testing/tiny-random-LlamaForCausalLM",
                "--dtype",
                "float16",
            ],
            detach=True,
            name="lighteval-tgi-model-test",
            auto_remove=False,
            ports={"80/tcp": port},
        )
    address = f"http://localhost:{port}"
    for _ in range(40):
        try:
            if requests.get(f"{address}/health"):
                break
        except Exception:
            time.sleep(1)
    else:
        raise RuntimeError("Couldn't setup TGI server.")
    
    if request.param == "async":
        yield httpx.AsyncClient(base_url=address)    
    elif request.param == "sync":
        yield httpx.Client(base_url=address)
    else:
        raise RuntimeError()


def test_logprobs(tgi_client: httpx.Client|httpx.AsyncClient):
    model: LlamaForCausalLM = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
    
    # The test fails in the async setting unless `prompts` has fewer than 3 items
    prompts = [
        "Tell me:\n\nHow are you?Fine, thanks!",
        "Tell me:\n\nHow are you?Not bad!",
        "Tell me:\n\nComment vas-tu?Comme ci, comme ça",
        "Tell me:\n\nComment vas-tu?Ca va! Merci!",
    ]
    responses = []
    for prompt in prompts:
        responses.append(tgi_client.post("/generate",json={
            "inputs": prompt,
            "parameters":{
                "details": True,
                "decoder_input_details": True,
                "max_new_tokens": 1,
                "stop_sequences": [],
                "do_sample": False,
                "return_full_text": False,
                "seed": 42,
            },
            "stream": False,
        }))
    if isinstance(tgi_client, httpx.AsyncClient):
        loop = asyncio.get_event_loop()
        responses: list[httpx.Response] = loop.run_until_complete(asyncio.gather(*responses))
    
    error = False
    for prompt, response in zip(prompts, responses):
        response: dict = response.json()
        tgi_logprobs = torch.tensor([t["logprob"] for t in response["details"]["prefill"][1:]]) # Skipping <s> whose logprob is None
        
        tokenized_sequence = tokenizer(prompt, return_tensors='pt')['input_ids']
        output = model.generate(tokenized_sequence, max_new_tokens=1, return_dict_in_generate=True, output_hidden_states=True)
        with torch.no_grad():
            logprobs = torch.log_softmax(model.lm_head(output.hidden_states[0][-1]),dim=-1)
        logprobs = logprobs.gather(dim=-1, index=tokenized_sequence[:,1:].unsqueeze(-1)).squeeze()
        
        if not torch.allclose(logprobs.sum(), tgi_logprobs.sum()):
            print(f"====== prompt: {repr(prompt)} ======")
            print("TGI logprobs:", tgi_logprobs.tolist())
            print("TGI tokens:",[t["id"] for t in response["details"]["prefill"]])
            print("Ref. logprobs:", logprobs.tolist())
            print("Ref. tokens:", tokenized_sequence[0].tolist())
            error = True
    assert not error
Output in the async case:

====== prompt: 'Tell me:\n\nHow are you?Not bad!' ======
TGI logprobs: [-10.390103340148926, -10.234614372253418, -10.583305358886719, -10.415820121765137, -10.48030948638916, -10.504646301269531, -10.360091209411621, -10.294694900512695, -10.363240242004395, -10.283424377441406, -10.3839693069458, -10.447381019592285]
TGI tokens: [1, 24948, 592, 29901, 13, 13, 5328, 526, 366, 29973, 3664, 4319, 29991]
Ref. logprobs: [-10.23465633392334, -10.583168983459473, -10.415797233581543, -10.480293273925781, -10.504618644714355, -10.360037803649902, -10.294654846191406, -10.363250732421875, -10.283416748046875, -10.38398265838623, -10.44735050201416, -10.320746421813965]
Ref. tokens: [1, 24948, 592, 29901, 13, 13, 5328, 526, 366, 29973, 3664, 4319, 29991]
====== prompt: 'Tell me:\n\nComment vas-tu?Ca va! Merci!' ======
TGI logprobs: [-10.385672569274902, -10.234614372253418, -10.583305358886719, -10.415820121765137, -10.48030948638916, -10.504646301269531, -10.199889183044434, -10.274925231933594, -10.404882431030273, -10.23915958404541, -10.417717933654785, -10.42142391204834, -10.367297172546387, -10.263328552246094, -10.426572799682617, -10.497278213500977]
TGI tokens: [1, 24948, 592, 29901, 13, 13, 20001, 19723, 29899, 9161, 29973, 26270, 2947, 29991, 4702, 455, 29991]
Ref. logprobs: [-10.23465633392334, -10.583168983459473, -10.415797233581543, -10.480293273925781, -10.504618644714355, -10.19986629486084, -10.274916648864746, -10.404905319213867, -10.239229202270508, -10.417610168457031, -10.421340942382812, -10.367280960083008, -10.263335227966309, -10.42658805847168, -10.49729061126709, -10.374680519104004]
Ref. tokens: [1, 24948, 592, 29901, 13, 13, 20001, 19723, 29899, 9161, 29973, 26270, 2947, 29991, 4702, 455, 29991]
