Allow payload request to support extra inference method kwargs #1505
First changed file (the request codec's decode_request):
@@ -1,3 +1,4 @@
+import logging
 from typing import Optional, Type, Any, Dict, List, Union, Sequence
 from mlserver.codecs.utils import (
     has_decoded,
@@ -170,6 +171,10 @@ def encode_request(cls, payload: Dict[str, Any], **kwargs) -> InferenceRequest:

     @classmethod
     def decode_request(cls, request: InferenceRequest) -> Dict[str, Any]:
+        """
+        Decode an InferenceRequest into a dictionary.
+        Extra inference kwargs can be kept in 'InferenceRequest.parameters.extra'.
+        """
         values = {}
         field_codecs = cls._find_decode_codecs(request)
         for item in request.inputs:
@@ -181,6 +186,15 @@ def decode_request(cls, request: InferenceRequest) -> Dict[str, Any]:

             value = get_decoded_or_raw(item)
             values[item.name] = value
+        if request.parameters is not None:
+            if hasattr(request.parameters, "extra"):
+                extra = request.parameters.extra
+                if isinstance(extra, dict):
+                    values.update(extra)
+                else:
+                    logging.warn(
+                        "Extra inference kwargs should be kept in a dictionary."
+                    )
         return values

Review comment (on the warning branch): Could you output the value of the parameter as well in the warning message? And perhaps change the message to something like "Extra parameters cannot be parsed, expected a dictionary", so it is more descriptive?
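A minimal sketch of how that suggestion could look, using logging.warning (logging.warn is deprecated). The helper name below is hypothetical and only illustrates the message format:

```python
import logging

logger = logging.getLogger(__name__)


def merge_extra_kwargs(values: dict, extra) -> dict:
    """Merge request-level extra kwargs into the decoded values (hypothetical helper)."""
    if isinstance(extra, dict):
        values.update(extra)
    else:
        # Per the review: name the problem and include the offending value.
        logger.warning(
            "Extra parameters cannot be parsed, expected a dictionary. Got '%s'.", extra
        )
    return values
```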
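For context, a hedged sketch of how a client could pass such extra generation kwargs over the V2 REST endpoint once this change lands. The host, port, and model name are assumptions, and the kwargs shown are only illustrative:

```python
import requests

# The text input travels as a BYTES tensor; the generation kwargs ride along in
# `parameters.extra` and are merged into the decoded payload by decode_request.
payload = {
    "inputs": [
        {"name": "args", "shape": [1], "datatype": "BYTES", "data": ["This is a test"]}
    ],
    "parameters": {"extra": {"max_new_tokens": 10, "return_full_text": False}},
}

# Assumed local MLServer instance and model name.
response = requests.post(
    "http://localhost:8080/v2/models/text-generation-model/infer",
    json=payload,
)
response.raise_for_status()
print(response.json()["outputs"][0]["data"])
```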
Second changed file (HuggingFace runtime tests):
@@ -1,5 +1,6 @@
 from unittest.mock import MagicMock, patch

+import json
 import pytest
 import torch
 from typing import Dict, Optional
@@ -13,6 +14,8 @@
 from mlserver_huggingface.runtime import HuggingFaceRuntime
 from mlserver_huggingface.settings import HuggingFaceSettings
 from mlserver_huggingface.common import load_pipeline_from_settings
+from mlserver.types import InferenceRequest, RequestInput
+from mlserver.types.dataplane import Parameters


 @pytest.mark.parametrize(
@@ -210,3 +213,47 @@ def test_pipeline_checks_for_eos_and_pad_token(
     m = load_pipeline_from_settings(hf_settings, model_settings)

     assert m._batch_size == expected_batch_size
+
+
+@pytest.mark.parametrize(
+    "inference_kwargs, expected_num_tokens",
+    [
+        ({"max_new_tokens": 10, "return_full_text": False}, 10),
+        ({"max_new_tokens": 20, "return_full_text": False}, 20),
+    ],
+)
+async def test_pipeline_uses_inference_kwargs(
+    inference_kwargs: Optional[dict],
+    expected_num_tokens: int,
+):
+    model_settings = ModelSettings(
+        name="foo",
+        implementation=HuggingFaceRuntime,
+        parameters=ModelParameters(
+            extra={
+                "pretrained_model": "Maykeye/TinyLLama-v0",
+                "task": "text-generation",
+            }
+        ),
+    )
+    runtime = HuggingFaceRuntime(model_settings)
+    runtime.ready = await runtime.load()
+    payload = InferenceRequest(
+        inputs=[
+            RequestInput(
+                name="args",
+                shape=[1],
+                datatype="BYTES",
+                data=["This is a test"],
+            )
+        ],
+        parameters=Parameters(extra=inference_kwargs),
+    )
+    tokenizer = runtime._model.tokenizer
+
+    prediction = await runtime.predict(payload)
+    generated_text = json.loads(prediction.outputs[0].data[0])["generated_text"]
+
+    assert isinstance(generated_text, str)
+    tokenized_generated_text = tokenizer.tokenize(generated_text)
+    num_predicted_tokens = len(tokenized_generated_text)
+    assert num_predicted_tokens == expected_num_tokens

Review comment (on the json.loads line): Could you try to use the hf codec?
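A hedged sketch of what that review comment asks for: decoding the prediction through the HuggingFace codec rather than json.loads. The HuggingfaceRequestCodec import path and the decode_response call are assumptions based on how other MLServer request codecs are used, and the exact shape of the decoded output may differ:

```python
# Possible replacement for the json.loads-based extraction in the test above
# (a sketch only; import path and decode_response return shape are assumptions).
from mlserver_huggingface.codecs import HuggingfaceRequestCodec

prediction = await runtime.predict(payload)
decoded = HuggingfaceRequestCodec.decode_response(prediction)
# For a text-generation pipeline, the decoded output should expose the generated
# text, e.g. something like decoded[0]["generated_text"], depending on the codec.
```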