From 595f2bfab23e13c39a1b7970df4ef3a794ece763 Mon Sep 17 00:00:00 2001 From: ff137 Date: Wed, 30 Oct 2024 22:39:11 +0300 Subject: [PATCH] :zap: replace FastAPI `JSONResponse` with `ORJSONResponse` --- app/main.py | 18 ++++++++++-------- app/tests/test_main.py | 4 ++-- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/app/main.py b/app/main.py index 27313d002..4b7aa3929 100644 --- a/app/main.py +++ b/app/main.py @@ -7,7 +7,7 @@ from aries_cloudcontroller import ApiException from fastapi import FastAPI, Request, Response from fastapi.exceptions import HTTPException -from fastapi.responses import JSONResponse +from fastapi.responses import ORJSONResponse from scalar_fastapi import get_scalar_api_reference from app.exceptions import CloudApiException @@ -134,41 +134,43 @@ def read_openapi_yaml() -> Response: @app.exception_handler(Exception) -async def universal_exception_handler(_: Request, exception: Exception) -> JSONResponse: +async def universal_exception_handler( + _: Request, exception: Exception +) -> ORJSONResponse: stacktrace = {"traceback": traceback.format_exc()} if debug else {} if isinstance(exception, CloudApiException): - return JSONResponse( + return ORJSONResponse( content={"detail": exception.detail, **stacktrace}, status_code=exception.status_code, ) if isinstance(exception, CloudApiValueError): - return JSONResponse( + return ORJSONResponse( {"detail": exception.detail, **stacktrace}, status_code=422, ) if isinstance(exception, pydantic.ValidationError): - return JSONResponse( + return ORJSONResponse( {"detail": extract_validation_error_msg(exception), **stacktrace}, status_code=422, ) if isinstance(exception, ApiException): - return JSONResponse( + return ORJSONResponse( {"detail": exception.reason, **stacktrace}, status_code=exception.status, ) if isinstance(exception, HTTPException): - return JSONResponse( + return ORJSONResponse( {"detail": exception.detail, **stacktrace}, status_code=exception.status_code, headers=exception.headers, ) - return JSONResponse( + return ORJSONResponse( {"detail": "Internal server error", "exception": str(exception), **stacktrace}, status_code=500, ) diff --git a/app/tests/test_main.py b/app/tests/test_main.py index 2493f9d5f..7ad8c2278 100644 --- a/app/tests/test_main.py +++ b/app/tests/test_main.py @@ -4,7 +4,7 @@ import pytest from aries_cloudcontroller import ApiException from fastapi import HTTPException, Request -from fastapi.responses import JSONResponse +from fastapi.responses import ORJSONResponse from app.exceptions.cloudapi_exception import CloudApiException from app.main import ( @@ -109,6 +109,6 @@ async def test_universal_exception_handler(): for exception, expected_status, expected_detail in test_cases: request = Mock(spec=Request) response = await universal_exception_handler(request, exception) - assert isinstance(response, JSONResponse) + assert isinstance(response, ORJSONResponse) assert response.status_code == expected_status assert expected_detail in response.body.decode()