Skip to content

Commit

Permalink
Propogate error type in header when reporting init error to RAPID (#166)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivenay authored Aug 8, 2024
1 parent f9d370f commit a37a43a
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 7 deletions.
2 changes: 1 addition & 1 deletion awslambdaric/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ def run(app_root, handler, lambda_runtime_api_addr):

if error_result is not None:
log_error(error_result, log_sink)
lambda_runtime_client.post_init_error(to_json(error_result))
lambda_runtime_client.post_init_error(error_result)

sys.exit(1)

Expand Down
8 changes: 7 additions & 1 deletion awslambdaric/lambda_runtime_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
import sys
from awslambdaric import __version__
from .lambda_runtime_exception import FaultException
from .lambda_runtime_marshaller import to_json

ERROR_TYPE_HEADER = "Lambda-Runtime-Function-Error-Type"


def _user_agent():
Expand Down Expand Up @@ -68,7 +71,10 @@ def post_init_error(self, error_response_data):
runtime_connection = http.client.HTTPConnection(self.lambda_runtime_address)
runtime_connection.connect()
endpoint = "/2018-06-01/runtime/init/error"
runtime_connection.request("POST", endpoint, error_response_data)
headers = {ERROR_TYPE_HEADER: error_response_data["errorType"]}
runtime_connection.request(
"POST", endpoint, to_json(error_response_data), headers=headers
)
response = runtime_connection.getresponse()
response_body = response.read()

Expand Down
23 changes: 18 additions & 5 deletions tests/test_lambda_runtime_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
LambdaRuntimeClientError,
_user_agent,
)
from awslambdaric.lambda_runtime_marshaller import to_json


class TestInvocationRequest(unittest.TestCase):
Expand Down Expand Up @@ -99,6 +100,15 @@ def test_wait_next_invocation(self, mock_runtime_client):
self.assertEqual(event_request.content_type, "application/json")
self.assertEqual(event_request.event_body, response_body)

error_result = {
"errorMessage": "Dummy message",
"errorType": "Runtime.DummyError",
"requestId": "",
"stackTrace": [],
}

headers = {"Lambda-Runtime-Function-Error-Type": error_result["errorType"]}

@patch("http.client.HTTPConnection", autospec=http.client.HTTPConnection)
def test_post_init_error(self, MockHTTPConnection):
mock_conn = MockHTTPConnection.return_value
Expand All @@ -108,11 +118,14 @@ def test_post_init_error(self, MockHTTPConnection):
mock_response.code = http.HTTPStatus.ACCEPTED

runtime_client = LambdaRuntimeClient("localhost:1234")
runtime_client.post_init_error("error_data")
runtime_client.post_init_error(self.error_result)

MockHTTPConnection.assert_called_with("localhost:1234")
mock_conn.request.assert_called_once_with(
"POST", "/2018-06-01/runtime/init/error", "error_data"
"POST",
"/2018-06-01/runtime/init/error",
to_json(self.error_result),
headers=self.headers,
)
mock_response.read.assert_called_once()

Expand All @@ -127,7 +140,7 @@ def test_post_init_error_non_accepted_status_code(self, MockHTTPConnection):
runtime_client = LambdaRuntimeClient("localhost:1234")

with self.assertRaises(LambdaRuntimeClientError) as cm:
runtime_client.post_init_error("error_data")
runtime_client.post_init_error(self.error_result)
returned_exception = cm.exception

self.assertEqual(returned_exception.endpoint, "/2018-06-01/runtime/init/error")
Expand Down Expand Up @@ -215,12 +228,12 @@ def test_post_invocation_error_with_too_large_xray_cause(self, mock_runtime_clie
def test_connection_refused(self):
with self.assertRaises(ConnectionRefusedError):
runtime_client = LambdaRuntimeClient("127.0.0.1:1")
runtime_client.post_init_error("error")
runtime_client.post_init_error(self.error_result)

def test_invalid_addr(self):
with self.assertRaises(OSError):
runtime_client = LambdaRuntimeClient("::::")
runtime_client.post_init_error("error")
runtime_client.post_init_error(self.error_result)

def test_lambdaric_version(self):
self.assertTrue(_user_agent().endswith(__version__))
Expand Down

0 comments on commit a37a43a

Please sign in to comment.