Allow payload request to support extra inference method kwargs #1505

Closed
wants to merge 23 commits
19 changes: 19 additions & 0 deletions runtimes/huggingface/README.md
@@ -66,6 +66,25 @@ Models in the HuggingFace hub can be loaded by specifying their name in `parameters.
If `parameters.extra.pretrained_model` is specified, it takes precedence over `parameters.uri`.
````

#### Model Inference
Model inference is performed by the HuggingFace pipeline, which allows users to run inference on a batch of inputs. Extra inference kwargs can be passed in `parameters.extra`.
```{code-block} json

{
  "inputs": [
    {
      "name": "text_inputs",
      "shape": [1],
      "datatype": "BYTES",
      "data": ["My kitten's name is JoJo,", "Tell me a story:"]
    }
  ],
  "parameters": {
    "extra": {"max_new_tokens": 200, "return_full_text": false}
  }
}
```
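
As a rough sketch, a payload like the one above can be sent to MLServer's v2 inference endpoint; the host, port, and model name below are placeholders for your own deployment:

```{code-block} python
# Sketch: send the example payload to MLServer's v2 inference endpoint.
# "localhost:8080" and "text-generation-model" are placeholders, not fixed names.
import requests

payload = {
    "inputs": [
        {
            "name": "text_inputs",
            "shape": [1],
            "datatype": "BYTES",
            "data": ["My kitten's name is JoJo,", "Tell me a story:"],
        }
    ],
    "parameters": {"extra": {"max_new_tokens": 200, "return_full_text": False}},
}

response = requests.post(
    "http://localhost:8080/v2/models/text-generation-model/infer",
    json=payload,
)
print(response.json())
```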

### Reference

You can find the full reference of the accepted extra settings for the
14 changes: 14 additions & 0 deletions runtimes/huggingface/mlserver_huggingface/codecs/base.py
@@ -1,3 +1,4 @@
import logging
from typing import Optional, Type, Any, Dict, List, Union, Sequence
from mlserver.codecs.utils import (
has_decoded,
@@ -170,6 +171,10 @@ def encode_request(cls, payload: Dict[str, Any], **kwargs) -> InferenceRequest:

@classmethod
def decode_request(cls, request: InferenceRequest) -> Dict[str, Any]:
"""
Decode Inference requst into dictionary
Suggested change:
- Decode Inference requst into dictionary
+ Decode Inference request into dictionary

extra Inference kwargs can be kept in 'InferenceRequest.parameters.extra'
Suggested change:
- extra Inference kwargs can be kept in 'InferenceRequest.parameters.extra'
+ extra Inference kwargs are extracted from 'InferenceRequest.parameters.extra'

"""
values = {}
field_codecs = cls._find_decode_codecs(request)
for item in request.inputs:
@@ -181,6 +186,15 @@ def decode_request(cls, request: InferenceRequest) -> Dict[str, Any]:

value = get_decoded_or_raw(item)
values[item.name] = value
if request.parameters is not None:
if hasattr(request.parameters, "extra"):
extra = request.parameters.extra
if isinstance(extra, dict):
values.update(extra)
else:
logging.warning(
"Extra inference kwargs should be kept in a dictionary."
)

Could you output the value of the parameter as well in the warning message? And perhaps change the warning message to be in the form of "Extra parameters cannot be parsed, expected a dictionary" to be more descriptive in the message?
return values
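
A possible shape for the more descriptive warning the reviewer asks for above (a sketch with assumed wording, not the PR's final message):

```python
# Sketch only: report that the extra parameters could not be parsed and include
# the offending value, as requested in the review comment above.
logging.warning(
    "Extra parameters cannot be parsed, expected a dictionary, "
    f"got {extra!r} (type {type(extra).__name__})."
)
```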


20 changes: 18 additions & 2 deletions runtimes/huggingface/tests/test_codecs.py
@@ -28,12 +28,28 @@
]
),
{"foo": ["bar1", "bar2"], "foo2": ["var1"]},
)
),
(
InferenceRequest(
parameters=Parameters(content_type="str", extra={"foo3": "var2"}),
inputs=[
RequestInput(
name="foo",
datatype="BYTES",
data=["bar1", "bar2"],
shape=[2, 1],
),
RequestInput(
name="foo2", datatype="BYTES", data=["var1"], shape=[1, 1]
),
],
),
{"foo": ["bar1", "bar2"], "foo2": ["var1"], "foo3": "var2"},
),
],
)
def test_decode_request(inference_request, expected):
payload = HuggingfaceRequestCodec.decode_request(inference_request)

assert payload == expected


47 changes: 47 additions & 0 deletions runtimes/huggingface/tests/test_common.py
@@ -1,5 +1,6 @@
from unittest.mock import MagicMock, patch

import json
import pytest
import torch
from typing import Dict, Optional
@@ -13,6 +14,8 @@
from mlserver_huggingface.runtime import HuggingFaceRuntime
from mlserver_huggingface.settings import HuggingFaceSettings
from mlserver_huggingface.common import load_pipeline_from_settings
from mlserver.types import InferenceRequest, RequestInput
from mlserver.types.dataplane import Parameters


@pytest.mark.parametrize(
@@ -210,3 +213,47 @@ def test_pipeline_checks_for_eos_and_pad_token(
m = load_pipeline_from_settings(hf_settings, model_settings)

assert m._batch_size == expected_batch_size


@pytest.mark.parametrize(
"inference_kwargs, expected_num_tokens",
[
({"max_new_tokens": 10, "return_full_text": False}, 10),
({"max_new_tokens": 20, "return_full_text": False}, 20),
],
)
async def test_pipeline_uses_inference_kwargs(
inference_kwargs: Optional[dict],
expected_num_tokens: int,
):
model_settings = ModelSettings(
name="foo",
implementation=HuggingFaceRuntime,
parameters=ModelParameters(
extra={
"pretrained_model": "Maykeye/TinyLLama-v0",
"task": "text-generation",
}
),
)
runtime = HuggingFaceRuntime(model_settings)
runtime.ready = await runtime.load()
payload = InferenceRequest(
inputs=[
RequestInput(
name="args",
shape=[1],
datatype="BYTES",
data=["This is a test"],
)
],
parameters=Parameters(extra=inference_kwargs),
)
tokenizer = runtime._model.tokenizer

prediction = await runtime.predict(payload)
generated_text = json.loads(prediction.outputs[0].data[0])["generated_text"]
Could you try to use the hf codec decode_response method and check if it makes this line a bit more readable?

assert isinstance(generated_text, str)
tokenized_generated_text = tokenizer.tokenize(generated_text)
num_predicted_tokens = len(tokenized_generated_text)
assert num_predicted_tokens == expected_num_tokens
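
A rough sketch of the decode_response suggestion from the review comment above. The codec's decode_response method is the one the reviewer refers to, but the exact shape of the decoded output below (an "output" key holding the pipeline's list of dicts) is an assumption, not verified API:

```python
# Sketch only: decode the response via the HuggingFace codec instead of parsing
# the raw JSON string by hand. The "output" key and nested structure are assumed.
from mlserver_huggingface.codecs import HuggingfaceRequestCodec

decoded = HuggingfaceRequestCodec.decode_response(prediction)
generated_text = decoded["output"][0]["generated_text"]
assert isinstance(generated_text, str)
```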