Skip to content

Commit

Permalink
MMS - Support for custom error codes in custom handlers (#965)
Browse files Browse the repository at this point in the history
Support for custom error codes
  • Loading branch information
dhanainme authored Apr 26, 2021
1 parent c58e29b commit aa7c230
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 0 deletions.
12 changes: 12 additions & 0 deletions docs/custom_service.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,18 @@ Here the ``` handle()``` method is our entry point that will be invoked by MMS,

This entry point is engaged in two cases: (1) when MMS is asked to scale a model up, to increase the number of backend workers (it is done either via a ```PUT /models/{model_name}``` request or a ```POST /models``` request with `initial-workers` option or during MMS startup when you use `--models` option (```multi-model-server --start --models {model_name=model.mar}```), ie., you provide model(s) to load) or (2) when MMS gets a ```POST /predictions/{model_name}``` request. (1) is used to scale-up or scale-down workers for a model. (2) is used as a standard way to run inference against a model. (1) is also known as model load time, and that is where you would normally want to put code for model initialization. You can find out more about these and other MMS APIs in [MMS Management API](./management_api.md) and [MMS Inference API](./inference_api.md)


### Returning custom error codes

To return a custom error code back to the user use the `PredictionException` in the `mms.service` module.

```python
from mms.service import PredictionException
def handler(data, context):
# Some unexpected error - returning error code 513
raise PredictionException("Some Prediction Error", 513)
```

## Creating model archive with entry point

MMS, identifies the entry point to the custom service, from the manifest file. Thus file creating the model archive, one needs to mention the entry point using the ```--handler``` option.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
{
"specificationVersion": "1.0",
"implementationVersion": "1.0",
"description": "noop v1.0",
"modelServerVersion": "1.0",
"license": "Apache 2.0",
"runtime": "python",
"model": {
"modelName": "pred-custom-return-code",
"description": "Tests for custom return code",
"modelVersion": "1.0",
"handler": "service:handle"
},
"publisher": {
"author": "MXNet SDK team",
"email": "[email protected]"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
# http://www.apache.org/licenses/LICENSE-2.0
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from mms.service import PredictionException

def handle(data, ctx):
# Data is not none in prediction request
# Python raises PredictionException with custom error code
if data is not None:
raise PredictionException("Some Prediction Error", 599)
return ["OK"]
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ public void test()
testLoggingUnload(channel, managementChannel);
testLoadingMemoryError();
testPredictionMemoryError();
testPredictionCustomErrorCode();
testMetricManager();
testErrorBatch();

Expand Down Expand Up @@ -1189,6 +1190,51 @@ private void testPredictionMemoryError() throws InterruptedException {
Assert.assertEquals(httpStatus, HttpResponseStatus.OK);
}

private void testPredictionCustomErrorCode() throws InterruptedException {
// Load the model
Channel channel = connect(true);
Assert.assertNotNull(channel);
result = null;
latch = new CountDownLatch(1);
DefaultFullHttpRequest req =
new DefaultFullHttpRequest(
HttpVersion.HTTP_1_1,
HttpMethod.POST,
"/models?url=custom-return-code&model_name=custom-return-code&runtime=python&initial_workers=1&synchronous=true");
channel.writeAndFlush(req);
latch.await();
Assert.assertEquals(httpStatus, HttpResponseStatus.OK);
channel.close();

// Test for prediction
channel = connect(false);
Assert.assertNotNull(channel);
result = null;
latch = new CountDownLatch(1);
req =
new DefaultFullHttpRequest(
HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/custom-return-code");
req.content().writeCharSequence("data=invalid_output", CharsetUtil.UTF_8);

channel.writeAndFlush(req);
latch.await();

Assert.assertEquals(httpStatus.code(), 599);
channel.close();

// Unload the model
channel = connect(true);
httpStatus = null;
latch = new CountDownLatch(1);
Assert.assertNotNull(channel);
req =
new DefaultFullHttpRequest(
HttpVersion.HTTP_1_1, HttpMethod.DELETE, "/models/custom-return-code");
channel.writeAndFlush(req);
latch.await();
Assert.assertEquals(httpStatus, HttpResponseStatus.OK);
}

private void testErrorBatch() throws InterruptedException {
Channel channel = connect(true);
Assert.assertNotNull(channel);
Expand Down
13 changes: 13 additions & 0 deletions mms/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@ def predict(self, batch):
# noinspection PyBroadException
try:
ret = self._entry_point(input_batch, self.context)
except PredictionException as e:
logger.error("Prediction error", exc_info=True)
return create_predict_response(None, req_id_map, e.message, e.error_code)
except MemoryError:
logger.error("System out of memory", exc_info=True)
return create_predict_response(None, req_id_map, "Out of resources", 507)
Expand All @@ -128,6 +131,16 @@ def predict(self, batch):
return create_predict_response(ret, req_id_map, "Prediction success", 200, context=self.context)


class PredictionException(Exception):
def __init__(self, message, error_code=500):
self.message = message
self.error_code = error_code
super(PredictionException, self).__init__(message)

def __str__(self):
return "message : error_code".format(message=self.message, error_code=self.error_code)


def emit_metrics(metrics):
"""
Emit the metrics in the provided Dictionary
Expand Down

0 comments on commit aa7c230

Please sign in to comment.