Skip to content

Commit

Permalink
Add support for snapstart runtime hooks (#176)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivenay authored Nov 18, 2024
1 parent 079135e commit 349d36a
Show file tree
Hide file tree
Showing 8 changed files with 280 additions and 12 deletions.
32 changes: 30 additions & 2 deletions awslambdaric/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
_AWS_LAMBDA_LOG_LEVEL = _get_log_level_from_env_var(
os.environ.get("AWS_LAMBDA_LOG_LEVEL")
)
AWS_LAMBDA_INITIALIZATION_TYPE = "AWS_LAMBDA_INITIALIZATION_TYPE"
INIT_TYPE_SNAP_START = "snap-start"


def _get_handler(handler):
Expand Down Expand Up @@ -286,6 +288,29 @@ def extract_traceback(tb):
]


def on_init_complete(lambda_runtime_client, log_sink):
from . import lambda_runtime_hooks_runner

try:
lambda_runtime_hooks_runner.run_before_snapshot()
lambda_runtime_client.restore_next()
except:
error_result = build_fault_result(sys.exc_info(), None)
log_error(error_result, log_sink)
lambda_runtime_client.post_init_error(
error_result, FaultException.BEFORE_SNAPSHOT_ERROR
)
sys.exit(64)

try:
lambda_runtime_hooks_runner.run_after_restore()
except:
error_result = build_fault_result(sys.exc_info(), None)
log_error(error_result, log_sink)
lambda_runtime_client.report_restore_error(error_result)
sys.exit(65)


class LambdaLoggerHandler(logging.Handler):
def __init__(self, log_sink):
logging.Handler.__init__(self)
Expand Down Expand Up @@ -454,10 +479,10 @@ def run(app_root, handler, lambda_runtime_api_addr):
sys.stdout = Unbuffered(sys.stdout)
sys.stderr = Unbuffered(sys.stderr)

use_thread_for_polling_next = os.environ.get("AWS_EXECUTION_ENV") in [
use_thread_for_polling_next = os.environ.get("AWS_EXECUTION_ENV") in {
"AWS_Lambda_python3.12",
"AWS_Lambda_python3.13",
]
}

with create_log_sink() as log_sink:
lambda_runtime_client = LambdaRuntimeClient(
Expand Down Expand Up @@ -485,6 +510,9 @@ def run(app_root, handler, lambda_runtime_api_addr):

sys.exit(1)

if os.environ.get(AWS_LAMBDA_INITIALIZATION_TYPE) == INIT_TYPE_SNAP_START:
on_init_complete(lambda_runtime_client, log_sink)

while True:
event_request = lambda_runtime_client.wait_next_invocation()

Expand Down
50 changes: 41 additions & 9 deletions awslambdaric/lambda_runtime_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,25 +62,57 @@ def __init__(self, lambda_runtime_address, use_thread_for_polling_next=False):
# Not defining symbol as global to avoid relying on TPE being imported unconditionally.
self.ThreadPoolExecutor = ThreadPoolExecutor

def post_init_error(self, error_response_data):
def call_rapid(
self, http_method, endpoint, expected_http_code, payload=None, headers=None
):
# These imports are heavy-weight. They implicitly trigger `import ssl, hashlib`.
# Importing them lazily to speed up critical path of a common case.
import http
import http.client

runtime_connection = http.client.HTTPConnection(self.lambda_runtime_address)
runtime_connection.connect()
endpoint = "/2018-06-01/runtime/init/error"
headers = {ERROR_TYPE_HEADER: error_response_data["errorType"]}
runtime_connection.request(
"POST", endpoint, to_json(error_response_data), headers=headers
)
if http_method == "GET":
runtime_connection.request(http_method, endpoint)
else:
runtime_connection.request(
http_method, endpoint, to_json(payload), headers=headers
)

response = runtime_connection.getresponse()
response_body = response.read()

if response.code != http.HTTPStatus.ACCEPTED:
if response.code != expected_http_code:
raise LambdaRuntimeClientError(endpoint, response.code, response_body)

def post_init_error(self, error_response_data, error_type_override=None):
import http

endpoint = "/2018-06-01/runtime/init/error"
headers = {
ERROR_TYPE_HEADER: (
error_type_override
if error_type_override
else error_response_data["errorType"]
)
}
self.call_rapid(
"POST", endpoint, http.HTTPStatus.ACCEPTED, error_response_data, headers
)

def restore_next(self):
import http

endpoint = "/2018-06-01/runtime/restore/next"
self.call_rapid("GET", endpoint, http.HTTPStatus.OK)

def report_restore_error(self, restore_error_data):
import http

endpoint = "/2018-06-01/runtime/restore/error"
headers = {ERROR_TYPE_HEADER: FaultException.AFTER_RESTORE_ERROR}
self.call_rapid(
"POST", endpoint, http.HTTPStatus.ACCEPTED, restore_error_data, headers
)

def wait_next_invocation(self):
# Calling runtime_client.next() from a separate thread unblocks the main thread,
# which can then process signals.
Expand Down
2 changes: 2 additions & 0 deletions awslambdaric/lambda_runtime_exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ class FaultException(Exception):
IMPORT_MODULE_ERROR = "Runtime.ImportModuleError"
BUILT_IN_MODULE_CONFLICT = "Runtime.BuiltInModuleConflict"
MALFORMED_HANDLER_NAME = "Runtime.MalformedHandlerName"
BEFORE_SNAPSHOT_ERROR = "Runtime.BeforeSnapshotError"
AFTER_RESTORE_ERROR = "Runtime.AfterRestoreError"
LAMBDA_CONTEXT_UNMARSHAL_ERROR = "Runtime.LambdaContextUnmarshalError"
LAMBDA_RUNTIME_CLIENT_ERROR = "Runtime.LambdaRuntimeClientError"

Expand Down
18 changes: 18 additions & 0 deletions awslambdaric/lambda_runtime_hooks_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

from snapshot_restore_py import get_before_snapshot, get_after_restore


def run_before_snapshot():
before_snapshot_callables = get_before_snapshot()
while before_snapshot_callables:
# Using pop as before checkpoint callables are executed in the reverse order of their registration
func, args, kwargs = before_snapshot_callables.pop()
func(*args, **kwargs)


def run_after_restore():
after_restore_callables = get_after_restore()
for func, args, kwargs in after_restore_callables:
func(*args, **kwargs)
1 change: 1 addition & 0 deletions requirements/base.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
simplejson>=3.18.4
snapshot-restore-py>=1.0.0
51 changes: 50 additions & 1 deletion tests/test_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import unittest
from io import StringIO
from tempfile import NamedTemporaryFile
from unittest.mock import MagicMock, Mock, patch
from unittest.mock import MagicMock, Mock, patch, ANY

import awslambdaric.bootstrap as bootstrap
from awslambdaric.lambda_runtime_exception import FaultException
Expand All @@ -23,6 +23,7 @@
from awslambdaric.lambda_literals import (
lambda_unhandled_exception_warning_message,
)
import snapshot_restore_py


class TestUpdateXrayEnv(unittest.TestCase):
Expand Down Expand Up @@ -1457,5 +1458,53 @@ class TestException(Exception):
mock_sys.exit.assert_called_once_with(1)


class TestOnInitComplete(unittest.TestCase):
def tearDown(self):
# We are accessing private filed for cleaning up
snapshot_restore_py._before_snapshot_registry = []
snapshot_restore_py._after_restore_registry = []

# We are using ANY over here as the main thing we want to test is teh errorType propogation and stack trace generation
error_result = {
"errorMessage": "This is a Dummy type error",
"errorType": "TypeError",
"requestId": "",
"stackTrace": ANY,
}

def raise_type_error(self):
raise TypeError("This is a Dummy type error")

@patch("awslambdaric.bootstrap.LambdaRuntimeClient")
def test_before_snapshot_exception(self, mock_runtime_client):
snapshot_restore_py.register_before_snapshot(self.raise_type_error)

with self.assertRaises(SystemExit) as cm:
bootstrap.on_init_complete(
mock_runtime_client, log_sink=bootstrap.StandardLogSink()
)

self.assertEqual(cm.exception.code, 64)
mock_runtime_client.post_init_error.assert_called_once_with(
self.error_result,
FaultException.BEFORE_SNAPSHOT_ERROR,
)

@patch("awslambdaric.bootstrap.LambdaRuntimeClient")
def test_after_restore_exception(self, mock_runtime_client):
snapshot_restore_py.register_after_restore(self.raise_type_error)

with self.assertRaises(SystemExit) as cm:
bootstrap.on_init_complete(
mock_runtime_client, log_sink=bootstrap.StandardLogSink()
)

self.assertEqual(cm.exception.code, 65)
mock_runtime_client.restore_next.assert_called_once()
mock_runtime_client.report_restore_error.assert_called_once_with(
self.error_result
)


if __name__ == "__main__":
unittest.main()
73 changes: 73 additions & 0 deletions tests/test_lambda_runtime_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,21 @@ def test_wait_next_invocation(self, mock_runtime_client):

headers = {"Lambda-Runtime-Function-Error-Type": error_result["errorType"]}

restore_error_result = {
"errorMessage": "Dummy Restore error",
"errorType": "Runtime.DummyRestoreError",
"requestId": "",
"stackTrace": [],
}

restore_error_header = {
"Lambda-Runtime-Function-Error-Type": "Runtime.AfterRestoreError"
}

before_snapshot_error_header = {
"Lambda-Runtime-Function-Error-Type": "Runtime.BeforeSnapshotError"
}

@patch("http.client.HTTPConnection", autospec=http.client.HTTPConnection)
def test_post_init_error(self, MockHTTPConnection):
mock_conn = MockHTTPConnection.return_value
Expand Down Expand Up @@ -225,6 +240,64 @@ def test_post_invocation_error_with_too_large_xray_cause(self, mock_runtime_clie
invoke_id, error_data, ""
)

@patch("http.client.HTTPConnection", autospec=http.client.HTTPConnection)
def test_restore_next(self, MockHTTPConnection):
mock_conn = MockHTTPConnection.return_value
mock_response = MagicMock(autospec=http.client.HTTPResponse)
mock_conn.getresponse.return_value = mock_response
mock_response.read.return_value = b""
mock_response.code = http.HTTPStatus.OK

runtime_client = LambdaRuntimeClient("localhost:1234")
runtime_client.restore_next()

MockHTTPConnection.assert_called_with("localhost:1234")
mock_conn.request.assert_called_once_with(
"GET",
"/2018-06-01/runtime/restore/next",
)
mock_response.read.assert_called_once()

@patch("http.client.HTTPConnection", autospec=http.client.HTTPConnection)
def test_restore_error(self, MockHTTPConnection):
mock_conn = MockHTTPConnection.return_value
mock_response = MagicMock(autospec=http.client.HTTPResponse)
mock_conn.getresponse.return_value = mock_response
mock_response.read.return_value = b""
mock_response.code = http.HTTPStatus.ACCEPTED

runtime_client = LambdaRuntimeClient("localhost:1234")
runtime_client.report_restore_error(self.restore_error_result)

MockHTTPConnection.assert_called_with("localhost:1234")
mock_conn.request.assert_called_once_with(
"POST",
"/2018-06-01/runtime/restore/error",
to_json(self.restore_error_result),
headers=self.restore_error_header,
)
mock_response.read.assert_called_once()

@patch("http.client.HTTPConnection", autospec=http.client.HTTPConnection)
def test_init_before_snapshot_error(self, MockHTTPConnection):
mock_conn = MockHTTPConnection.return_value
mock_response = MagicMock(autospec=http.client.HTTPResponse)
mock_conn.getresponse.return_value = mock_response
mock_response.read.return_value = b""
mock_response.code = http.HTTPStatus.ACCEPTED

runtime_client = LambdaRuntimeClient("localhost:1234")
runtime_client.post_init_error(self.error_result, "Runtime.BeforeSnapshotError")

MockHTTPConnection.assert_called_with("localhost:1234")
mock_conn.request.assert_called_once_with(
"POST",
"/2018-06-01/runtime/init/error",
to_json(self.error_result),
headers=self.before_snapshot_error_header,
)
mock_response.read.assert_called_once()

def test_connection_refused(self):
with self.assertRaises(ConnectionRefusedError):
runtime_client = LambdaRuntimeClient("127.0.0.1:1")
Expand Down
65 changes: 65 additions & 0 deletions tests/test_runtime_hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

import unittest
from unittest.mock import patch, call
from awslambdaric import lambda_runtime_hooks_runner
import snapshot_restore_py


def fun_test1():
print("In function ONE")


def fun_test2():
print("In function TWO")


def fun_with_args_kwargs(x, y, **kwargs):
print("Here are the args:", x, y)
print("Here are the keyword args:", kwargs)


class TestRuntimeHooks(unittest.TestCase):
def tearDown(self):
# We are accessing private filed for cleaning up
snapshot_restore_py._before_snapshot_registry = []
snapshot_restore_py._after_restore_registry = []

@patch("builtins.print")
def test_before_snapshot_execution_order(self, mock_print):
snapshot_restore_py.register_before_snapshot(
fun_with_args_kwargs, 5, 7, arg1="Lambda", arg2="SnapStart"
)
snapshot_restore_py.register_before_snapshot(fun_test2)
snapshot_restore_py.register_before_snapshot(fun_test1)

lambda_runtime_hooks_runner.run_before_snapshot()

calls = []
calls.append(call("In function ONE"))
calls.append(call("In function TWO"))
calls.append(call("Here are the args:", 5, 7))
calls.append(
call("Here are the keyword args:", {"arg1": "Lambda", "arg2": "SnapStart"})
)
self.assertEqual(calls, mock_print.mock_calls)

@patch("builtins.print")
def test_after_restore_execution_order(self, mock_print):
snapshot_restore_py.register_after_restore(
fun_with_args_kwargs, 11, 13, arg1="Lambda", arg2="SnapStart"
)
snapshot_restore_py.register_after_restore(fun_test2)
snapshot_restore_py.register_after_restore(fun_test1)

lambda_runtime_hooks_runner.run_after_restore()

calls = []
calls.append(call("Here are the args:", 11, 13))
calls.append(
call("Here are the keyword args:", {"arg1": "Lambda", "arg2": "SnapStart"})
)
calls.append(call("In function TWO"))
calls.append(call("In function ONE"))
self.assertEqual(calls, mock_print.mock_calls)

0 comments on commit 349d36a

Please sign in to comment.