Skip to content

Commit

Permalink
Add structured logging implementation (#101)
Browse files Browse the repository at this point in the history
Co-authored-by: Artem Krivonos <[email protected]>
  • Loading branch information
kotyara1005 and Artem Krivonos authored Aug 22, 2023
1 parent d53a9bc commit 8b34c48
Show file tree
Hide file tree
Showing 5 changed files with 436 additions and 55 deletions.
140 changes: 97 additions & 43 deletions awslambdaric/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,19 @@
from .lambda_context import LambdaContext
from .lambda_runtime_client import LambdaRuntimeClient
from .lambda_runtime_exception import FaultException
from .lambda_runtime_log_utils import (
_DATETIME_FORMAT,
_DEFAULT_FRAME_TYPE,
_JSON_FRAME_TYPES,
JsonFormatter,
LogFormat,
)
from .lambda_runtime_marshaller import to_json

ERROR_LOG_LINE_TERMINATE = "\r"
# Replacement indent character handed to replace_line_indentation() when
# stack-trace lines are written to the error log.
ERROR_LOG_IDENT = "\u00a0"  # NO-BREAK SPACE U+00A0
# Log format, resolved once at import time from the environment. Any value
# other than "JSON" (case-insensitive), including an unset variable, selects
# the text format.
_AWS_LAMBDA_LOG_FORMAT = LogFormat.from_str(os.environ.get("AWS_LAMBDA_LOG_FORMAT"))
# Requested log level name, upper-cased; empty string when the variable is
# unset.
_AWS_LAMBDA_LOG_LEVEL = os.environ.get("AWS_LAMBDA_LOG_LEVEL", "").upper()


def _get_handler(handler):
Expand Down Expand Up @@ -73,7 +82,12 @@ def result(*args):
return result


def make_error(error_message, error_type, stack_trace, invoke_id=None):
def make_error(
error_message,
error_type,
stack_trace,
invoke_id=None,
):
result = {
"errorMessage": error_message if error_message else "",
"errorType": error_type if error_type else "",
Expand All @@ -92,34 +106,52 @@ def replace_line_indentation(line, indent_char, new_indent_char):
return (new_indent_char * ident_chars_count) + line[ident_chars_count:]


def log_error(error_result, log_sink):
error_description = "[ERROR]"
if _AWS_LAMBDA_LOG_FORMAT == LogFormat.JSON:
_ERROR_FRAME_TYPE = _JSON_FRAME_TYPES[logging.ERROR]

def log_error(error_result, log_sink):
error_result = {
"timestamp": time.strftime(
_DATETIME_FORMAT, logging.Formatter.converter(time.time())
),
"log_level": "ERROR",
**error_result,
}
log_sink.log_error(
[to_json(error_result)],
)

error_result_type = error_result.get("errorType")
if error_result_type:
error_description += " " + error_result_type
else:
_ERROR_FRAME_TYPE = _DEFAULT_FRAME_TYPE

error_result_message = error_result.get("errorMessage")
if error_result_message:
def log_error(error_result, log_sink):
error_description = "[ERROR]"

error_result_type = error_result.get("errorType")
if error_result_type:
error_description += ":"
error_description += " " + error_result_message
error_description += " " + error_result_type

error_result_message = error_result.get("errorMessage")
if error_result_message:
if error_result_type:
error_description += ":"
error_description += " " + error_result_message

error_message_lines = [error_description]
error_message_lines = [error_description]

stack_trace = error_result.get("stackTrace")
if stack_trace is not None:
error_message_lines += ["Traceback (most recent call last):"]
for trace_element in stack_trace:
if trace_element == "":
error_message_lines += [""]
else:
for trace_line in trace_element.splitlines():
error_message_lines += [
replace_line_indentation(trace_line, " ", ERROR_LOG_IDENT)
]
stack_trace = error_result.get("stackTrace")
if stack_trace is not None:
error_message_lines += ["Traceback (most recent call last):"]
for trace_element in stack_trace:
if trace_element == "":
error_message_lines += [""]
else:
for trace_line in trace_element.splitlines():
error_message_lines += [
replace_line_indentation(trace_line, " ", ERROR_LOG_IDENT)
]

log_sink.log_error(error_message_lines)
log_sink.log_error(error_message_lines)


def handle_event_request(
Expand Down Expand Up @@ -152,7 +184,12 @@ def handle_event_request(
)
except FaultException as e:
xray_fault = make_xray_fault("LambdaValidationError", e.msg, os.getcwd(), [])
error_result = make_error(e.msg, e.exception_type, e.trace, invoke_id)
error_result = make_error(
e.msg,
e.exception_type,
e.trace,
invoke_id,
)

except Exception:
etype, value, tb = sys.exc_info()
Expand Down Expand Up @@ -221,7 +258,9 @@ def build_fault_result(exc_info, msg):
break

return make_error(
msg if msg else str(value), etype.__name__, traceback.format_list(tb_tuples)
msg if msg else str(value),
etype.__name__,
traceback.format_list(tb_tuples),
)


Expand Down Expand Up @@ -257,7 +296,8 @@ def __init__(self, log_sink):

def emit(self, record):
msg = self.format(record)
self.log_sink.log(msg)

self.log_sink.log(msg, frame_type=getattr(record, "_frame_type", None))


class LambdaLoggerFilter(logging.Filter):
Expand Down Expand Up @@ -298,7 +338,7 @@ def __enter__(self):
def __exit__(self, exc_type, exc_value, exc_tb):
pass

def log(self, msg):
def log(self, msg, frame_type=None):
sys.stdout.write(msg)

def log_error(self, message_lines):
Expand All @@ -324,7 +364,6 @@ class FramedTelemetryLogSink(object):

def __init__(self, fd):
self.fd = int(fd)
self.frame_type = 0xA55A0003.to_bytes(4, "big")

def __enter__(self):
self.file = os.fdopen(self.fd, "wb", 0)
Expand All @@ -333,11 +372,12 @@ def __enter__(self):
def __exit__(self, exc_type, exc_value, exc_tb):
self.file.close()

def log(self, msg):
def log(self, msg, frame_type=None):
encoded_msg = msg.encode("utf8")

timestamp = int(time.time_ns() / 1000) # UNIX timestamp in microseconds
log_msg = (
self.frame_type
(frame_type or _DEFAULT_FRAME_TYPE)
+ len(encoded_msg).to_bytes(4, "big")
+ timestamp.to_bytes(8, "big")
+ encoded_msg
Expand All @@ -346,7 +386,10 @@ def log(self, msg):

def log_error(self, message_lines):
error_message = "\n".join(message_lines)
self.log(error_message)
self.log(
error_message,
frame_type=_ERROR_FRAME_TYPE,
)


def update_xray_env_variable(xray_trace_id):
Expand All @@ -370,6 +413,28 @@ def create_log_sink():
_GLOBAL_AWS_REQUEST_ID = None


def _setup_logging(log_format, log_level, log_sink):
    """Attach a Lambda log handler to the root logger.

    Args:
        log_format: A ``LogFormat`` value. ``LogFormat.JSON`` installs a
            ``JsonFormatter``; anything else installs the tab-separated
            text formatter.
        log_level: Upper-cased level name. It is applied to the root logger
            only in JSON mode, and only when it is a level name known to
            ``logging``; otherwise the logger level is left unchanged.
        log_sink: Sink object the ``LambdaLoggerHandler`` writes to.
    """
    # All formatters render timestamps from UTC; both output formats append
    # a literal "Z" to the formatted time.
    logging.Formatter.converter = time.gmtime

    root_logger = logging.getLogger()
    handler = LambdaLoggerHandler(log_sink)

    if log_format != LogFormat.JSON:
        text_formatter = logging.Formatter(
            "[%(levelname)s]\t%(asctime)s.%(msecs)03dZ\t%(aws_request_id)s\t%(message)s\n",
            "%Y-%m-%dT%H:%M:%S",
        )
        handler.setFormatter(text_formatter)
    else:
        handler.setFormatter(JsonFormatter())

        # Register "TRACE" as the name of level 10 so that it is accepted as
        # a level name below and reported as the level name of DEBUG records.
        logging.addLevelName(logging.DEBUG, "TRACE")
        if log_level in logging._nameToLevel:
            root_logger.setLevel(log_level)

    handler.addFilter(LambdaLoggerFilter())
    root_logger.addHandler(handler)


def run(app_root, handler, lambda_runtime_api_addr):
sys.stdout = Unbuffered(sys.stdout)
sys.stderr = Unbuffered(sys.stderr)
Expand All @@ -378,18 +443,7 @@ def run(app_root, handler, lambda_runtime_api_addr):
lambda_runtime_client = LambdaRuntimeClient(lambda_runtime_api_addr)

try:
logging.Formatter.converter = time.gmtime
logger = logging.getLogger()
logger_handler = LambdaLoggerHandler(log_sink)
logger_handler.setFormatter(
logging.Formatter(
"[%(levelname)s]\t%(asctime)s.%(msecs)03dZ\t%(aws_request_id)s\t%(message)s\n",
"%Y-%m-%dT%H:%M:%S",
)
)
logger_handler.addFilter(LambdaLoggerFilter())
logger.addHandler(logger_handler)

_setup_logging(_AWS_LAMBDA_LOG_FORMAT, _AWS_LAMBDA_LOG_LEVEL, log_sink)
global _GLOBAL_AWS_REQUEST_ID

request_handler = _get_handler(handler)
Expand Down
123 changes: 123 additions & 0 deletions awslambdaric/lambda_runtime_log_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""
Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
"""

import json
import logging
import traceback
from enum import IntEnum
from typing import Optional

# strftime pattern used for JSON log timestamps. The trailing "Z" is a literal
# character (the ISO 8601 UTC designator), so the time being formatted is
# expected to already be in UTC.
_DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ"
# LogRecord attribute names that JsonFormatter.format() never copies into the
# JSON document as extra fields. Besides the standard record attributes this
# includes "aws_request_id" (emitted under the "requestId" key instead) and
# "_frame_type" (set on the record by the formatter itself).
_RESERVED_FIELDS = {
    "name",
    "msg",
    "args",
    "levelname",
    "levelno",
    "pathname",
    "filename",
    "module",
    "exc_info",
    "exc_text",
    "stack_info",
    "lineno",
    "funcName",
    "created",
    "msecs",
    "relativeCreated",
    "thread",
    "threadName",
    "processName",
    "process",
    "aws_request_id",
    "_frame_type",
}


class LogFormat(IntEnum):
    """Output format for runtime log records."""

    JSON = 0b0
    TEXT = 0b1

    @classmethod
    def from_str(cls, value: Optional[str]) -> "LogFormat":
        """Parse a format name into a ``LogFormat`` member.

        Args:
            value: Format name, typically read from an environment variable.
                May be ``None`` or empty.

        Returns:
            ``LogFormat.JSON`` when ``value`` equals "JSON" ignoring case,
            otherwise ``LogFormat.TEXT``.
        """
        # Return the members rather than their raw ints: an IntEnum member
        # still compares equal to its int value, so existing ``==`` checks
        # keep working, while callers gain a proper enum (name, repr, ``is``).
        if value and value.upper() == "JSON":
            return cls.JSON
        return cls.TEXT


# 4-byte big-endian frame-type headers for JSON-formatted records, keyed by
# standard logging level number. JsonFormatter.format() picks one per record
# and falls back to the NOTSET entry for levels not listed here.
_JSON_FRAME_TYPES = {
    logging.NOTSET: 0xA55A0002.to_bytes(4, "big"),
    logging.DEBUG: 0xA55A000A.to_bytes(4, "big"),
    logging.INFO: 0xA55A000E.to_bytes(4, "big"),
    logging.WARNING: 0xA55A0012.to_bytes(4, "big"),
    logging.ERROR: 0xA55A0016.to_bytes(4, "big"),
    logging.CRITICAL: 0xA55A001A.to_bytes(4, "big"),
}
# Frame-type header used when no level-specific frame type is supplied.
_DEFAULT_FRAME_TYPE = 0xA55A0003.to_bytes(4, "big")

# One shared encoder instance; ensure_ascii=False leaves non-ASCII characters
# unescaped in the output.
_json_encoder = json.JSONEncoder(ensure_ascii=False)
# Bound-method alias used by JsonFormatter.format() to serialize a record.
_encode_json = _json_encoder.encode


class JsonFormatter(logging.Formatter):
    """Render a ``LogRecord`` as a single newline-terminated JSON object.

    The object always carries ``timestamp``, ``level``, ``message`` and
    ``logger``. ``requestId`` is added when the record has an
    ``aws_request_id`` attribute, and ``stackTrace``, ``errorType``,
    ``errorMessage`` and ``location`` are added when the record carries an
    exception. Any other non-reserved record attribute (for example values
    passed through ``extra=``) is merged in, unless it would overwrite one
    of the keys above. Keys whose value is ``None`` are omitted.
    """

    def __init__(self):
        super().__init__(datefmt=_DATETIME_FORMAT)

    @staticmethod
    def __normalize_exc_info(exc_info):
        """Return ``exc_info`` only when it describes a real exception.

        Logging with ``exc_info=True`` outside an ``except`` block stores
        ``(None, None, None)`` on the record. That tuple is truthy but has
        no exception type, so it is treated the same as "no exception".
        """
        if not exc_info or exc_info[0] is None:
            return None
        return exc_info

    @staticmethod
    def __format_stacktrace(exc_info):
        """Return the traceback as a list of formatted frame strings."""
        if not exc_info:
            return None
        return traceback.format_tb(exc_info[2])

    @staticmethod
    def __format_exception_name(exc_info):
        """Return the exception class name."""
        if not exc_info:
            return None
        return exc_info[0].__name__

    @staticmethod
    def __format_exception(exc_info):
        """Return ``str()`` of the exception instance."""
        if not exc_info:
            return None
        return str(exc_info[1])

    @staticmethod
    def __format_location(record: logging.LogRecord):
        """Return ``path:function:line`` of the logging call site."""
        return f"{record.pathname}:{record.funcName}:{record.lineno}"

    @staticmethod
    def __normalize_level(levelno):
        """Map any level number onto a standard level.

        The number is clamped to ``[NOTSET, CRITICAL]`` and rounded down to
        a multiple of 10 (so a custom level 25 is reported as level 20).

        Returns:
            ``(levelno, levelname)`` for the normalized level.
        """
        normalized = min(logging.CRITICAL, max(logging.NOTSET, levelno)) // 10 * 10
        return normalized, logging.getLevelName(normalized)

    def format(self, record: logging.LogRecord) -> str:
        """Serialize ``record`` to a JSON line.

        Side effect: stores the frame type matching the record's normalized
        level on ``record._frame_type`` for the emitting handler to read.
        ``record.levelno`` and ``record.levelname`` are left untouched,
        because the same record object is handed to every other handler.
        """
        levelno, levelname = self.__normalize_level(record.levelno)
        record._frame_type = _JSON_FRAME_TYPES.get(
            levelno, _JSON_FRAME_TYPES[logging.NOTSET]
        )

        exc_info = self.__normalize_exc_info(record.exc_info)

        result = {
            "timestamp": self.formatTime(record, self.datefmt),
            "level": levelname,
            "message": record.getMessage(),
            "logger": record.name,
            "stackTrace": self.__format_stacktrace(exc_info),
            "errorType": self.__format_exception_name(exc_info),
            "errorMessage": self.__format_exception(exc_info),
            "requestId": getattr(record, "aws_request_id", None),
            # The call site is reported only alongside an exception.
            "location": self.__format_location(record) if exc_info else None,
        }
        result.update(
            (key, value)
            for key, value in record.__dict__.items()
            if key not in _RESERVED_FIELDS and key not in result
        )

        result = {k: v for k, v in result.items() if v is not None}

        try:
            return _encode_json(result) + "\n"
        except TypeError:
            # An extra field is not JSON-serializable. Emit its str() form
            # instead of raising and losing the whole log line.
            return json.dumps(result, ensure_ascii=False, default=str) + "\n"
Loading

0 comments on commit 8b34c48

Please sign in to comment.