Allow payload request to support extra inference method kwargs #1505

Closed
wants to merge 23 commits
19 changes: 19 additions & 0 deletions runtimes/huggingface/README.md
@@ -66,6 +66,25 @@ Models in the HuggingFace hub can be loaded by specifying their name in `parameters.
If `parameters.extra.pretrained_model` is specified, it takes precedence over `parameters.uri`.
````

#### Model Inference
Model inference is performed by the HuggingFace pipeline, which allows users to run inference on a batch of inputs. Extra inference kwargs can be passed through `parameters.extra`:
```{code-block} json

{
  "inputs": [
    {
      "name": "text_inputs",
      "shape": [2],
      "datatype": "BYTES",
      "data": ["My kitten's name is JoJo,", "Tell me a story:"]
    }
  ],
  "parameters": {
    "extra": {"max_new_tokens": 200, "return_full_text": false}
  }
}
```
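
As a quick usage sketch, the payload above can be POSTed to a running MLServer instance over the V2 inference endpoint. The host, port, and model name below (`text-model`) are assumptions to adapt to your deployment:

```python
import requests

# Assumed host/port (MLServer's default HTTP port) and a hypothetical model name.
inference_url = "http://localhost:8080/v2/models/text-model/infer"

payload = {
    "inputs": [
        {
            "name": "text_inputs",
            "shape": [2],
            "datatype": "BYTES",
            "data": ["My kitten's name is JoJo,", "Tell me a story:"],
        }
    ],
    # Extra kwargs are forwarded to the HuggingFace pipeline call.
    "parameters": {"extra": {"max_new_tokens": 200, "return_full_text": False}},
}

response = requests.post(inference_url, json=payload)
print(response.json())
```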

### Reference

You can find the full reference of the accepted extra settings for the
16 changes: 16 additions & 0 deletions runtimes/huggingface/mlserver_huggingface/codecs/base.py
@@ -1,3 +1,4 @@
import logging
from typing import Optional, Type, Any, Dict, List, Union, Sequence
from mlserver.codecs.utils import (
has_decoded,
@@ -170,6 +171,10 @@ def encode_request(cls, payload: Dict[str, Any], **kwargs) -> InferenceRequest:

@classmethod
def decode_request(cls, request: InferenceRequest) -> Dict[str, Any]:
"""
Decode Inference request into dictionary
extra Inference kwargs are extracted from 'InferenceRequest.parameters.extra'
"""
values = {}
field_codecs = cls._find_decode_codecs(request)
for item in request.inputs:
@@ -181,6 +186,17 @@ def decode_request(cls, request: InferenceRequest) -> Dict[str, Any]:

value = get_decoded_or_raw(item)
values[item.name] = value
if request.parameters is not None:
if hasattr(request.parameters, "extra"):
extra = request.parameters.extra
if isinstance(extra, dict):
values.update(extra)
else:
logging.warning(
f"Extra parameters were provided with value '{extra}' and "
f"type '{type(extra)}'. Extra parameters cannot be parsed; "
"a dictionary was expected."
)
return values
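
To make the new behaviour concrete, here is a minimal sketch of decoding a request that carries `parameters.extra` (import paths assumed from the test files in this PR); the decoded inputs and the extra kwargs are merged into a single dictionary:

```python
from mlserver.types import InferenceRequest, RequestInput
from mlserver.types.dataplane import Parameters
from mlserver_huggingface.codecs import HuggingfaceRequestCodec

request = InferenceRequest(
    parameters=Parameters(extra={"max_new_tokens": 200}),
    inputs=[
        RequestInput(
            name="text_inputs",
            datatype="BYTES",
            data=["Tell me a story:"],
            shape=[1],
        )
    ],
)

# Expected to yield roughly:
# {"text_inputs": ["Tell me a story:"], "max_new_tokens": 200}
print(HuggingfaceRequestCodec.decode_request(request))
```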


80 changes: 77 additions & 3 deletions runtimes/huggingface/tests/test_codecs.py
@@ -1,5 +1,5 @@
import pytest

import logging
from mlserver.types import (
InferenceRequest,
InferenceResponse,
@@ -28,15 +28,89 @@
]
),
{"foo": ["bar1", "bar2"], "foo2": ["var1"]},
),
(
InferenceRequest(
parameters=Parameters(content_type="str", extra={"foo3": "var2"}),
inputs=[
RequestInput(
name="foo",
datatype="BYTES",
data=["bar1", "bar2"],
shape=[2, 1],
),
RequestInput(
name="foo2", datatype="BYTES", data=["var1"], shape=[1, 1]
),
],
),
{"foo": ["bar1", "bar2"], "foo2": ["var1"], "foo3": "var2"},
),
],
)
def test_decode_request(inference_request, expected):
payload = HuggingfaceRequestCodec.decode_request(inference_request)

assert payload == expected


@pytest.mark.parametrize(
"inference_request, expected_payload, expected_log_msg",
[
(
InferenceRequest(
parameters=Parameters(content_type="str", extra="foo3"),
inputs=[
RequestInput(
name="foo",
datatype="BYTES",
data=["bar1", "bar2"],
shape=[2, 1],
),
RequestInput(
name="foo2", datatype="BYTES", data=["var1"], shape=[1, 1]
),
],
),
{"foo": ["bar1", "bar2"]},
logging.warn(
"Extra parameters is provided with ",
+"value: 'foo3' and type '<class 'str'> \n",
+"Extra parameters cannot be parsed, expected a dictionary.",
),
),
(
InferenceRequest(
parameters=Parameters(content_type="str", extra=1234),
inputs=[
RequestInput(
name="foo",
datatype="BYTES",
data=["bar1", "bar2"],
shape=[2, 1],
),
RequestInput(
name="foo2", datatype="BYTES", data=["var1"], shape=[1, 1]
),
],
),
{"foo": ["bar1", "bar2"]},
logging.warn(
"Extra parameters is provided with "
+ "value '1234' and type '<class 'int'> \n",
+"Extra parameters cannot be parsed, expected a dictionary.",
),
),
],
)
def test_decode_request_with_invalid_parameter_extra(
inference_request, expected_payload, expected_log_msg, caplog
):
caplog.set_level(logging.WARN)
payload = HuggingfaceRequestCodec.decode_request(inference_request)
assert payload == expected_payload
assert expected_log_msg in caplog.text


@pytest.mark.parametrize(
"payload, use_bytes, expected",
[
49 changes: 49 additions & 0 deletions runtimes/huggingface/tests/test_common.py
@@ -13,6 +13,9 @@
from mlserver_huggingface.runtime import HuggingFaceRuntime
from mlserver_huggingface.settings import HuggingFaceSettings
from mlserver_huggingface.common import load_pipeline_from_settings
from mlserver.types import InferenceRequest, RequestInput
from mlserver.types.dataplane import Parameters
from mlserver_huggingface.codecs.base import MultiInputRequestCodec


@pytest.mark.parametrize(
@@ -210,3 +213,49 @@ def test_pipeline_checks_for_eos_and_pad_token(
m = load_pipeline_from_settings(hf_settings, model_settings)

assert m._batch_size == expected_batch_size


@pytest.mark.parametrize(
"inference_kwargs, expected_num_tokens",
[
({"max_new_tokens": 10, "return_full_text": False}, 10),
({"max_new_tokens": 20, "return_full_text": False}, 20),
],
)
async def test_pipeline_uses_inference_kwargs(
inference_kwargs: Optional[dict],
expected_num_tokens: int,
):
model_settings = ModelSettings(
name="foo",
implementation=HuggingFaceRuntime,
parameters=ModelParameters(
extra={
"pretrained_model": "Maykeye/TinyLLama-v0",
"task": "text-generation",
}
),
)
runtime = HuggingFaceRuntime(model_settings)
runtime.ready = await runtime.load()
payload = InferenceRequest(
inputs=[
RequestInput(
name="args",
shape=[1],
datatype="BYTES",
data=["This is a test"],
)
],
parameters=Parameters(extra=inference_kwargs),
)
tokenizer = runtime._model.tokenizer

prediction = await runtime.predict(payload)
decoded_prediction = MultiInputRequestCodec.decode_response(prediction)
if isinstance(decoded_prediction, dict):
generated_text = decoded_prediction["output"][0]["generated_text"]
assert isinstance(generated_text, str)
tokenized_generated_text = tokenizer.tokenize(generated_text)
num_predicted_tokens = len(tokenized_generated_text)
assert num_predicted_tokens == expected_num_tokens
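
For context on what this test exercises: once `decode_request` has merged `parameters.extra` into the decoded payload, the runtime is expected to pass those entries through as keyword arguments to the underlying pipeline call. A simplified sketch of that idea, not the runtime's actual code, might look like:

```python
from transformers import pipeline

# Illustrative only: separate pipeline inputs from extra inference kwargs,
# mirroring what the decoded payload contains in the test above.
generator = pipeline("text-generation", model="Maykeye/TinyLLama-v0")

decoded = {"args": ["This is a test"], "max_new_tokens": 10, "return_full_text": False}
inputs = decoded.pop("args")

# Remaining entries (max_new_tokens, return_full_text) are forwarded as kwargs.
outputs = generator(inputs, **decoded)
print(outputs[0][0]["generated_text"])
```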