Skip to content

Commit

Permalink
Fix accept header parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
avillar committed Aug 26, 2024
1 parent 4669336 commit 97e95a7
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 3 deletions.
9 changes: 8 additions & 1 deletion app/main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from contextlib import asynccontextmanager
from typing import Annotated, Union

Expand Down Expand Up @@ -40,10 +41,16 @@ async def lifespan(app: FastAPI):

app.profile_loader = None

logger = logging.getLogger('uvicorn.error')


@app.get('/', response_model=model.LandingPage)
def capabilities(req: Request, accept: Annotated[str | None, Header()] = None) -> model.LandingPage | HTMLResponse:
req_media_type = util.match_accept_header(accept, [MEDIA_TEXT_HTML, MEDIA_APPLICATION_JSON])
try:
req_media_type = util.match_accept_header(accept, [MEDIA_TEXT_HTML, MEDIA_APPLICATION_JSON])
except Exception as e:
req_media_type = MEDIA_TEXT_HTML
logger.error(f'Error parsing Accept header "{accept}"', e)

if req_media_type is None:
raise HTTPException(
Expand Down
24 changes: 22 additions & 2 deletions app/util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re
from pathlib import Path
from typing import Any


def is_xml(s: str) -> bool:
Expand All @@ -18,13 +19,32 @@ def match_media_type(a: list[str], b: list[str]):
return (a[0] == '*' or b[0] == '*' or a[0] == b[0]) and (a[1] == '*' or b[1] == '*' or a[1] == b[1])


def parse_accept_entries(header: str) -> dict[str, Any]:
entries = {}
for e in header.split(','):
parts = e.strip().split(';')
if not parts:
continue
entry = {
'q': 1.0
}
for part in parts[1:]:
kv = part.split('=', 1)
if kv[0] == 'q':
kv[1] = float(kv[1])
entry[kv[0]] = kv[1]
entries[parts[0]] = entry
return entries


def match_accept_header(header: str, options: list[str], def_value=None) -> str:
if not header:
return def_value

options = [o.strip().lower().split('/') for o in options]
for entry in sorted((e.strip().lower().split(';') for e in header.split(',')
if e.strip()), key=lambda e: float(e[1][2:]) if len(e) > 1 else 1.0):
entries = parse_accept_entries(header)
for entry in sorted(entries.items(), key=lambda e: -e[1]['q']):
print(entry)
parts = entry[0].split('/')
for option in options:
if match_media_type(option, parts):
Expand Down

0 comments on commit 97e95a7

Please sign in to comment.