Allow payload request to support extra inference method kwargs #1505

Closed
wants to merge 23 commits
5 changes: 5 additions & 0 deletions runtimes/huggingface/mlserver_huggingface/codecs/base.py
@@ -181,6 +181,11 @@ def decode_request(cls, request: InferenceRequest) -> Dict[str, Any]:

        value = get_decoded_or_raw(item)
        values[item.name] = value
    if request.parameters is not None:
        if hasattr(request.parameters, "extra"):
            extra = request.parameters.extra
            if isinstance(extra, dict):
                values.update(extra)
    return values


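For context, a minimal sketch of how a client could exercise this change over MLServer's REST API: the extra generation kwargs travel under the request's parameters.extra field, mirroring the Parameters(extra=...) usage in the test below, and are merged into the decoded request values that the runtime forwards as inference kwargs. The host, model name and the max_new_tokens kwarg are illustrative assumptions, not part of this PR.

import requests

# Hypothetical deployment: a HuggingFace text-generation model named
# "text-model" served by a local MLServer instance.
inference_url = "http://localhost:8080/v2/models/text-model/infer"

payload = {
    "inputs": [
        {
            "name": "args",
            "shape": [1],
            "datatype": "BYTES",
            "data": ["This is a test"],
        }
    ],
    # The new behaviour: anything under parameters.extra is picked up by
    # decode_request and passed along with the decoded inputs.
    "parameters": {"extra": {"max_new_tokens": 10}},
}

response = requests.post(inference_url, json=payload)
response.raise_for_status()

# The text-generation runtime returns a JSON-encoded dict per output element.
print(response.json()["outputs"][0]["data"][0])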
63 changes: 63 additions & 0 deletions runtimes/huggingface/tests/test_common.py
@@ -1,5 +1,6 @@
from unittest.mock import MagicMock, patch

import json
import pytest
import torch
from typing import Dict, Optional
@@ -13,6 +14,8 @@
from mlserver_huggingface.runtime import HuggingFaceRuntime
from mlserver_huggingface.settings import HuggingFaceSettings
from mlserver_huggingface.common import load_pipeline_from_settings
from mlserver.types import InferenceRequest, RequestInput
from mlserver.types.dataplane import Parameters


@pytest.mark.parametrize(
@@ -210,3 +213,63 @@ def test_pipeline_checks_for_eos_and_pad_token(
    m = load_pipeline_from_settings(hf_settings, model_settings)

    assert m._batch_size == expected_batch_size


@pytest.mark.parametrize(
    "inference_kwargs1, inference_kwargs2, expected",
    [
        (
            {"max_length": 20},
            {"max_length": 10},
            True,

Member:
It is unclear what `expected` means given this test case; I suggest refactoring a bit to make it clearer. It might be that you just need to assert that the number of tokens generated for each request matches the expected value (effectively converting this into 2 test cases).

Contributor (author):
Updated this into 2 test cases. Also asserting that the number of predicted tokens matches the expected number of tokens.

        )
    ],
)
async def test_pipeline_uses_inference_kwargs(
    inference_kwargs1: Optional[dict],
    inference_kwargs2: Optional[dict],
    expected: bool,
):
    model_settings = ModelSettings(
        name="foo",
        implementation=HuggingFaceRuntime,
        parameters=ModelParameters(
            extra={
                "pretrained_model": "Maykeye/TinyLLama-v0",
                "task": "text-generation",
            }
        ),
    )
    runtime = HuggingFaceRuntime(model_settings)
    runtime.ready = await runtime.load()
    payload1 = InferenceRequest(
        inputs=[
            RequestInput(
                name="args",
                shape=[1],
                datatype="BYTES",
                data=["This is a test"],
            )
        ],
        parameters=Parameters(extra=inference_kwargs1),
    )
    payload2 = InferenceRequest(
        inputs=[
            RequestInput(
                name="args",
                shape=[1],
                datatype="BYTES",
                data=["This is a test"],
            )
        ],
        parameters=Parameters(extra=inference_kwargs2),
    )

    result1 = await runtime.predict(payload1)
    generated_text1 = json.loads(result1.outputs[0].data[0])["generated_text"]
    assert isinstance(generated_text1, str)
    result2 = await runtime.predict(payload2)
    generated_text2 = json.loads(result2.outputs[0].data[0])["generated_text"]
    assert isinstance(generated_text2, str)
    comparison = len(generated_text1) > len(generated_text2)
    assert comparison == expected
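Following up on the review thread above, a rough sketch of what the suggested two-case refactor could look like. This is illustrative only, not the author's actual follow-up commit: the max_new_tokens kwarg, the re-tokenisation with AutoTokenizer and the tolerance in the assertion are assumptions layered on top of the test shown in this diff.

import json

import pytest
from transformers import AutoTokenizer

from mlserver.settings import ModelParameters, ModelSettings
from mlserver.types import InferenceRequest, RequestInput
from mlserver.types.dataplane import Parameters
from mlserver_huggingface.runtime import HuggingFaceRuntime

PROMPT = "This is a test"


# Assumes the same pytest-asyncio setup as the rest of this test module.
@pytest.mark.parametrize("max_new_tokens", [10, 20])
async def test_pipeline_respects_max_new_tokens(max_new_tokens: int):
    model_settings = ModelSettings(
        name="foo",
        implementation=HuggingFaceRuntime,
        parameters=ModelParameters(
            extra={
                "pretrained_model": "Maykeye/TinyLLama-v0",
                "task": "text-generation",
            }
        ),
    )
    runtime = HuggingFaceRuntime(model_settings)
    runtime.ready = await runtime.load()

    payload = InferenceRequest(
        inputs=[
            RequestInput(name="args", shape=[1], datatype="BYTES", data=[PROMPT])
        ],
        parameters=Parameters(extra={"max_new_tokens": max_new_tokens}),
    )
    result = await runtime.predict(payload)
    generated_text = json.loads(result.outputs[0].data[0])["generated_text"]

    # The text-generation pipeline returns the prompt plus the continuation,
    # so strip the prompt before counting tokens. Re-tokenising the decoded
    # string is not guaranteed to match the generated token ids exactly,
    # hence the small tolerance in the assertion.
    tokenizer = AutoTokenizer.from_pretrained("Maykeye/TinyLLama-v0")
    continuation = generated_text[len(PROMPT):]
    n_tokens = len(tokenizer(continuation, add_special_tokens=False)["input_ids"])
    assert n_tokens <= max_new_tokens + 1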