Skip to content

Commit

Permalink
Fix chat completion url for OpenAI compatibility (#2418)
Browse files Browse the repository at this point in the history
  • Loading branch information
Wauplin committed Jul 29, 2024
1 parent 09fc1d4 commit 9e05812
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 4 deletions.
7 changes: 5 additions & 2 deletions src/huggingface_hub/inference/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,13 +819,16 @@ def chat_completion(
# First, resolve the model chat completions URL
if model == self.base_url:
# base_url passed => add server route
model_url = model + "/v1/chat/completions"
model_url = model.rstrip("/")
if not model_url.endswith("/v1"):
model_url += "/v1"
model_url += "/chat/completions"
elif is_url:
# model is a URL => use it directly
model_url = model
else:
# model is a model ID => resolve it + add server route
model_url = self._resolve_url(model) + "/v1/chat/completions"
model_url = self._resolve_url(model).rstrip("/") + "/v1/chat/completions"

# `model` is sent in the payload. Not used by the server but can be useful for debugging/routing.
# If it's a ID on the Hub => use it. Otherwise, we use a random string.
Expand Down
7 changes: 5 additions & 2 deletions src/huggingface_hub/inference/_generated/_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,13 +825,16 @@ async def chat_completion(
# First, resolve the model chat completions URL
if model == self.base_url:
# base_url passed => add server route
model_url = model + "/v1/chat/completions"
model_url = model.rstrip("/")
if not model_url.endswith("/v1"):
model_url += "/v1"
model_url += "/chat/completions"
elif is_url:
# model is a URL => use it directly
model_url = model
else:
# model is a model ID => resolve it + add server route
model_url = self._resolve_url(model) + "/v1/chat/completions"
model_url = self._resolve_url(model).rstrip("/") + "/v1/chat/completions"

# `model` is sent in the payload. Not used by the server but can be useful for debugging/routing.
# If it's a ID on the Hub => use it. Otherwise, we use a random string.
Expand Down
21 changes: 21 additions & 0 deletions tests/test_inference_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -953,6 +953,27 @@ def test_model_and_base_url_mutually_exclusive(self):
InferenceClient(model="meta-llama/Meta-Llama-3-8B-Instruct", base_url="http://127.0.0.1:8000")


@pytest.mark.parametrize(
    "base_url",
    [
        "http://0.0.0.0:8080/v1",  # expected from OpenAI client
        "http://0.0.0.0:8080",  # but not mandatory
        "http://0.0.0.0:8080/v1/",  # ok with trailing '/' as well
        "http://0.0.0.0:8080/",  # ok with trailing '/' as well
    ],
)
def test_chat_completion_base_url_works_with_v1(base_url: str):
    """Check that the client normalizes `base_url` to end with `/v1/chat/completions`.

    Regression test for https://github.com/huggingface/huggingface_hub/issues/2414
    """
    with patch("huggingface_hub.inference._client.InferenceClient.post") as mocked_post:
        # Return a minimal JSON payload so the response parsing does not fail.
        mocked_post.return_value = "{}"
        client = InferenceClient(base_url=base_url)
        client.chat_completion(messages=CHAT_COMPLETION_MESSAGES, stream=False)
        first_call = mocked_post.call_args_list[0]
        assert first_call.kwargs["model"] == "http://0.0.0.0:8080/v1/chat/completions"


def test_stream_text_generation_response():
data = [
b'data: {"index":1,"token":{"id":4560,"text":" trying","logprob":-2.078125,"special":false},"generated_text":null,"details":null}',
Expand Down

0 comments on commit 9e05812

Please sign in to comment.