diff --git a/awslambdaric/bootstrap.py b/awslambdaric/bootstrap.py index 60aa216..0f19f56 100644 --- a/awslambdaric/bootstrap.py +++ b/awslambdaric/bootstrap.py @@ -480,7 +480,7 @@ def run(app_root, handler, lambda_runtime_api_addr): if error_result is not None: log_error(error_result, log_sink) - lambda_runtime_client.post_init_error(to_json(error_result)) + lambda_runtime_client.post_init_error(error_result) sys.exit(1) diff --git a/awslambdaric/lambda_runtime_client.py b/awslambdaric/lambda_runtime_client.py index 07243fc..036d10b 100644 --- a/awslambdaric/lambda_runtime_client.py +++ b/awslambdaric/lambda_runtime_client.py @@ -5,6 +5,9 @@ import sys from awslambdaric import __version__ from .lambda_runtime_exception import FaultException +from .lambda_runtime_marshaller import to_json + +ERROR_TYPE_HEADER = "Lambda-Runtime-Function-Error-Type" def _user_agent(): @@ -68,7 +71,10 @@ def post_init_error(self, error_response_data): runtime_connection = http.client.HTTPConnection(self.lambda_runtime_address) runtime_connection.connect() endpoint = "/2018-06-01/runtime/init/error" - runtime_connection.request("POST", endpoint, error_response_data) + headers = {ERROR_TYPE_HEADER: error_response_data["errorType"]} + runtime_connection.request( + "POST", endpoint, to_json(error_response_data), headers=headers + ) response = runtime_connection.getresponse() response_body = response.read() diff --git a/tests/test_lambda_runtime_client.py b/tests/test_lambda_runtime_client.py index b0eae4a..e09130b 100644 --- a/tests/test_lambda_runtime_client.py +++ b/tests/test_lambda_runtime_client.py @@ -14,6 +14,7 @@ LambdaRuntimeClientError, _user_agent, ) +from awslambdaric.lambda_runtime_marshaller import to_json class TestInvocationRequest(unittest.TestCase): @@ -99,6 +100,15 @@ def test_wait_next_invocation(self, mock_runtime_client): self.assertEqual(event_request.content_type, "application/json") self.assertEqual(event_request.event_body, response_body) + error_result = { + "errorMessage": "Dummy message", + "errorType": "Runtime.DummyError", + "requestId": "", + "stackTrace": [], + } + + headers = {"Lambda-Runtime-Function-Error-Type": error_result["errorType"]} + @patch("http.client.HTTPConnection", autospec=http.client.HTTPConnection) def test_post_init_error(self, MockHTTPConnection): mock_conn = MockHTTPConnection.return_value @@ -108,11 +118,14 @@ def test_post_init_error(self, MockHTTPConnection): mock_response.code = http.HTTPStatus.ACCEPTED runtime_client = LambdaRuntimeClient("localhost:1234") - runtime_client.post_init_error("error_data") + runtime_client.post_init_error(self.error_result) MockHTTPConnection.assert_called_with("localhost:1234") mock_conn.request.assert_called_once_with( - "POST", "/2018-06-01/runtime/init/error", "error_data" + "POST", + "/2018-06-01/runtime/init/error", + to_json(self.error_result), + headers=self.headers, ) mock_response.read.assert_called_once() @@ -127,7 +140,7 @@ def test_post_init_error_non_accepted_status_code(self, MockHTTPConnection): runtime_client = LambdaRuntimeClient("localhost:1234") with self.assertRaises(LambdaRuntimeClientError) as cm: - runtime_client.post_init_error("error_data") + runtime_client.post_init_error(self.error_result) returned_exception = cm.exception self.assertEqual(returned_exception.endpoint, "/2018-06-01/runtime/init/error") @@ -215,12 +228,12 @@ def test_post_invocation_error_with_too_large_xray_cause(self, mock_runtime_clie def test_connection_refused(self): with self.assertRaises(ConnectionRefusedError): runtime_client = LambdaRuntimeClient("127.0.0.1:1") - runtime_client.post_init_error("error") + runtime_client.post_init_error(self.error_result) def test_invalid_addr(self): with self.assertRaises(OSError): runtime_client = LambdaRuntimeClient("::::") - runtime_client.post_init_error("error") + runtime_client.post_init_error(self.error_result) def test_lambdaric_version(self): self.assertTrue(_user_agent().endswith(__version__))