Skip to content

Commit

Permalink
Add rate limiting; Refactor FastAPI code
Browse files Browse the repository at this point in the history
  • Loading branch information
mahesh-maan committed Sep 25, 2024
1 parent f64f5c4 commit 810df80
Show file tree
Hide file tree
Showing 5 changed files with 448 additions and 170 deletions.
78 changes: 0 additions & 78 deletions auth.py

This file was deleted.

151 changes: 151 additions & 0 deletions middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import os
import re
import time
import asyncio
from datetime import datetime
from collections import deque
import logging
from logging.handlers import TimedRotatingFileHandler

from fastapi import Request
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware

from config import config
from routes import routes_config

logger = logging.getLogger('API-ACCESS')
logger.setLevel(logging.DEBUG)

fh = TimedRotatingFileHandler(
'api-access.log',
when='midnight',
interval=1,
backupCount=7
)
fh.setLevel(logging.INFO)
fh.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
logger.addHandler(fh)

class CustomLogMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
ip = request.client.host
route = request.url.path

t0 = datetime.now()
response = await call_next(request)
dt = (datetime.now() - t0).total_seconds()

log_message = (
f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} "
f"{ip} "
f"{request.method} "
f"{route} "
f"{response.status_code} "
f"{dt:.2f}s"
)
logger.info(log_message)
return response


class AuthMiddleware(BaseHTTPMiddleware):
def __init__(self, app, tokens_file):
super().__init__(app)
self.tokens_file = tokens_file
self.tokens = set()
self.read_tokens()

async def dispatch(self, request: Request, call_next):
if not config.token_authentication_active:
return await call_next(request)

route = request.url.path

route_config = next((r for r in routes_config if r["path"] == route), None)
if not route_config:
return await call_next(request)

is_protected = route_config.get("protected", True)
if not is_protected:
return await call_next(request)

token = await self.extract_token(request)
if token is None or token not in self.tokens:
logger.info("%s - No token", route)
return JSONResponse(
status_code=401,
content={"detail": "Unauthorized"}
)

logger.info("%s - Valid token: %s", route, token)
return await call_next(request)

@staticmethod
async def extract_token(req: Request):
method = req.method
if method == "GET":
return req.query_params.get("token")
if method == "POST":
try:
json_body = await req.json()
return json_body.get("token")
except Exception as e:
logger.error("Error extracting token from POST body: %s", e)
return None
return None

def read_tokens(self):
if not os.path.isfile(self.tokens_file):
return
with open(self.tokens_file, "r", encoding="utf-8") as f:
lines = f.read().strip().splitlines()
lines = [l for l in lines if not l.startswith("#") and l.strip()]
tokens = [re.split(r'\s+', l)[0] for l in lines]
self.tokens.update(tokens)


class RateLimitMiddleware(BaseHTTPMiddleware):
def __init__(self, app, default_limit: int, window: int, routes_limits: dict = None):
super().__init__(app)
self.default_limit = default_limit # request volume per time window per client
self.window = window # window duration in seconds
self.request_log = {} # tracks request counts
self.lock = asyncio.Lock()

async def dispatch(self, request: Request, call_next):
route = request.url.path
route_config = next((r for r in routes_config if r["path"] == route), None)

if route_config is None:
return await call_next(request)

limit = route_config.get("rateLimit", self.default_limit)
if limit == -1:
return await call_next(request)

client_id = await AuthMiddleware.extract_token(request)

current_time = time.monotonic()
key = (client_id, route)

async with self.lock:
request_times = self.request_log.setdefault(key, deque())

while request_times and request_times[0] <= current_time - self.window:
request_times.popleft()

if len(request_times) >= limit:
retry_after = self.window - (current_time - request_times[0])
return JSONResponse(
status_code=429,
content={"detail": "Too many requests"},
headers={"Retry-After": str(int(retry_after))}
)

request_times.append(current_time)

# Clean up to prevent memory leak
if not request_times:
del self.requests_log[key]

return await call_next(request)
40 changes: 35 additions & 5 deletions plugins/miniapps/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,41 @@
import api as API

route_config = [
{"method": "GET", "path": "/suggest/cpcs", "handler": API.SuggestCPCs},
{"method": "GET", "path": "/predict/gaus", "handler": API.PredictGAUs},
{"method": "GET", "path": "/suggest/synonyms", "handler": API.SuggestSynonyms},
{"method": "GET", "path": "/extract/concepts", "handler": API.ExtractConcepts},
{"method": "GET", "path": "/definitions/cpcs", "handler": API.DefineCPC}
{
"method": "GET",
"path": "/suggest/cpcs",
"handler": API.SuggestCPCs,
"rateLimit": 5,
"protected": True
},
{
"method": "GET",
"path": "/predict/gaus",
"handler": API.PredictGAUs,
"rateLimit": 5,
"protected": True
},
{
"method": "GET",
"path": "/suggest/synonyms",
"handler": API.SuggestSynonyms,
"rateLimit": 5,
"protected": True
},
{
"method": "GET",
"path": "/extract/concepts",
"handler": API.ExtractConcepts,
"rateLimit": 5,
"protected": True
},
{
"method": "GET",
"path": "/definitions/cpcs",
"handler": API.DefineCPC,
"rateLimit": -1,
"protected": True
}
]

add_routes(app, route_config)
Loading

0 comments on commit 810df80

Please sign in to comment.