-
Notifications
You must be signed in to change notification settings - Fork 28
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add rate limiting; Refactor FastAPI code
- Loading branch information
1 parent
f64f5c4
commit 810df80
Showing
5 changed files
with
448 additions
and
170 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
import os | ||
import re | ||
import time | ||
import asyncio | ||
from datetime import datetime | ||
from collections import deque | ||
import logging | ||
from logging.handlers import TimedRotatingFileHandler | ||
|
||
from fastapi import Request | ||
from fastapi.responses import JSONResponse | ||
from starlette.middleware.base import BaseHTTPMiddleware | ||
|
||
from config import config | ||
from routes import routes_config | ||
|
||
logger = logging.getLogger('API-ACCESS') | ||
logger.setLevel(logging.DEBUG) | ||
|
||
fh = TimedRotatingFileHandler( | ||
'api-access.log', | ||
when='midnight', | ||
interval=1, | ||
backupCount=7 | ||
) | ||
fh.setLevel(logging.INFO) | ||
fh.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) | ||
logger.addHandler(fh) | ||
|
||
class CustomLogMiddleware(BaseHTTPMiddleware): | ||
async def dispatch(self, request: Request, call_next): | ||
ip = request.client.host | ||
route = request.url.path | ||
|
||
t0 = datetime.now() | ||
response = await call_next(request) | ||
dt = (datetime.now() - t0).total_seconds() | ||
|
||
log_message = ( | ||
f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} " | ||
f"{ip} " | ||
f"{request.method} " | ||
f"{route} " | ||
f"{response.status_code} " | ||
f"{dt:.2f}s" | ||
) | ||
logger.info(log_message) | ||
return response | ||
|
||
|
||
class AuthMiddleware(BaseHTTPMiddleware): | ||
def __init__(self, app, tokens_file): | ||
super().__init__(app) | ||
self.tokens_file = tokens_file | ||
self.tokens = set() | ||
self.read_tokens() | ||
|
||
async def dispatch(self, request: Request, call_next): | ||
if not config.token_authentication_active: | ||
return await call_next(request) | ||
|
||
route = request.url.path | ||
|
||
route_config = next((r for r in routes_config if r["path"] == route), None) | ||
if not route_config: | ||
return await call_next(request) | ||
|
||
is_protected = route_config.get("protected", True) | ||
if not is_protected: | ||
return await call_next(request) | ||
|
||
token = await self.extract_token(request) | ||
if token is None or token not in self.tokens: | ||
logger.info("%s - No token", route) | ||
return JSONResponse( | ||
status_code=401, | ||
content={"detail": "Unauthorized"} | ||
) | ||
|
||
logger.info("%s - Valid token: %s", route, token) | ||
return await call_next(request) | ||
|
||
@staticmethod | ||
async def extract_token(req: Request): | ||
method = req.method | ||
if method == "GET": | ||
return req.query_params.get("token") | ||
if method == "POST": | ||
try: | ||
json_body = await req.json() | ||
return json_body.get("token") | ||
except Exception as e: | ||
logger.error("Error extracting token from POST body: %s", e) | ||
return None | ||
return None | ||
|
||
def read_tokens(self): | ||
if not os.path.isfile(self.tokens_file): | ||
return | ||
with open(self.tokens_file, "r", encoding="utf-8") as f: | ||
lines = f.read().strip().splitlines() | ||
lines = [l for l in lines if not l.startswith("#") and l.strip()] | ||
tokens = [re.split(r'\s+', l)[0] for l in lines] | ||
self.tokens.update(tokens) | ||
|
||
|
||
class RateLimitMiddleware(BaseHTTPMiddleware): | ||
def __init__(self, app, default_limit: int, window: int, routes_limits: dict = None): | ||
super().__init__(app) | ||
self.default_limit = default_limit # request volume per time window per client | ||
self.window = window # window duration in seconds | ||
self.request_log = {} # tracks request counts | ||
self.lock = asyncio.Lock() | ||
|
||
async def dispatch(self, request: Request, call_next): | ||
route = request.url.path | ||
route_config = next((r for r in routes_config if r["path"] == route), None) | ||
|
||
if route_config is None: | ||
return await call_next(request) | ||
|
||
limit = route_config.get("rateLimit", self.default_limit) | ||
if limit == -1: | ||
return await call_next(request) | ||
|
||
client_id = await AuthMiddleware.extract_token(request) | ||
|
||
current_time = time.monotonic() | ||
key = (client_id, route) | ||
|
||
async with self.lock: | ||
request_times = self.request_log.setdefault(key, deque()) | ||
|
||
while request_times and request_times[0] <= current_time - self.window: | ||
request_times.popleft() | ||
|
||
if len(request_times) >= limit: | ||
retry_after = self.window - (current_time - request_times[0]) | ||
return JSONResponse( | ||
status_code=429, | ||
content={"detail": "Too many requests"}, | ||
headers={"Retry-After": str(int(retry_after))} | ||
) | ||
|
||
request_times.append(current_time) | ||
|
||
# Clean up to prevent memory leak | ||
if not request_times: | ||
del self.requests_log[key] | ||
|
||
return await call_next(request) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.