From ff03fd3faab132ec7420911bdcc0e407a21e7d63 Mon Sep 17 00:00:00 2001 From: Vignesh Rao Date: Tue, 17 Sep 2024 22:08:53 -0500 Subject: [PATCH] Add IP whitelisting functionality Restructure OAuth process Fix lazy check and improve datastructure on rate limiting --- docs/genindex.html | 17 ++++++++---- docs/index.html | 35 ++++++++++-------------- docs/objects.inv | Bin 882 -> 889 bytes docs/searchindex.js | 2 +- vaultapi/auth.py | 9 +++++++ vaultapi/main.py | 10 ++++++- vaultapi/models.py | 60 ++++++++++++++++++++--------------------- vaultapi/rate_limit.py | 54 +++++++++++++++++++------------------ 8 files changed, 102 insertions(+), 85 deletions(-) diff --git a/docs/genindex.html b/docs/genindex.html index a97f144..7b7a176 100644 --- a/docs/genindex.html +++ b/docs/genindex.html @@ -42,7 +42,8 @@

Navigation

Index

- A + _ + | A | C | D | E @@ -61,6 +62,14 @@

Index

| W
+

_

+ + +
+

A

    @@ -131,8 +140,6 @@

    E

    • enable_cors() (in module vaultapi.main) -
    • -
    • endpoints (vaultapi.models.EnvConfig attribute)
    • env (in module vaultapi.models)
    • @@ -271,11 +278,11 @@

      M

      P

        +
      • parse_allowed_origins() (vaultapi.models.EnvConfig class method) +
      • parse_api_secret() (vaultapi.models.EnvConfig class method)
      • parse_apikey() (vaultapi.models.EnvConfig class method) -
      • -
      • parse_endpoints() (vaultapi.models.EnvConfig class method)
      • port (vaultapi.models.EnvConfig attribute)
      • diff --git a/docs/index.html b/docs/index.html index e2528f3..0fd4cbe 100644 --- a/docs/index.html +++ b/docs/index.html @@ -323,11 +323,6 @@

        Modelsdatabase: Union[Path, Path, str]
        -
        -
        -endpoints: Union[Url, List[Url]]
        -
        -
        host: str
        @@ -350,7 +345,7 @@

        Models
        -allowed_origins: List[str]
        +allowed_origins: Union[Url, List[Url]]

        @@ -359,9 +354,9 @@

        Models

        -
        -classmethod parse_endpoints(value: Union[Url, List[Url]]) List[Url]
        -

        Validate endpoints to enable CORS policy.

        +
        +classmethod parse_allowed_origins(value: Union[Url, List[Url]]) List[Url]
        +

        Validate allowed origins to enable CORS policy.

        @@ -372,8 +367,8 @@

        Models
        -classmethod parse_api_secret(value: str | None) str | None
        -

        Parse API secret to validate complexity.

        +classmethod parse_api_secret(value: str) str +

        Parse API secret to Fernet compatible.

        @@ -419,16 +414,8 @@

        Models
        -vaultapi.models.complexity_checker(secret: str, simple: bool = False) None
        +vaultapi.models.complexity_checker(secret: str) None

        Verifies the strength of a secret.

        -
        -
        Parameters:
        -
          -
        • secret – Value of the secret.

        • -
        • simple – Boolean flag to increase complexity.

        • -
        -
        -

        See also

        A secret is considered strong if it at least has:

        @@ -526,10 +513,16 @@

        Payload

        RateLimit

        +
        +
        +vaultapi.rate_limit._get_identifier(request: Request) str
        +

        Generate a unique identifier for the request.

        +
        +
        class vaultapi.rate_limit.RateLimiter(rps: RateLimit)
        -

        Object that implements the RateLimiter functionality.

        +

        Rate limiter for incoming requests.

        >>> RateLimiter
         
        diff --git a/docs/objects.inv b/docs/objects.inv index ed7e6a5ca4c74696a4d94715ab618cd677e69454..67bcf83ddfb94cfd10840bfc70d3a211f1487824 100644 GIT binary patch delta 782 zcmV+p1M&Rw2Kfe%cz?le+b|48Y?Z0*x|U_bA^Vy!QNlj2Y|TIWXUi|y za?#H1^VRyt?Kekl8&Oy$MRqmk-UBV=tfUqK88oFEY9J=Ck$=ex4d{vBKOXBZFSS%@ zUG8I>}d}y_I6ib}3ALF`84o zVOCS!5!)#+gnx1_U!io8)LP9p4Yq`MS=giKr6Lt8fiU5UhzSB(4|ATRIV47;W;-S> z&{6Cex@3IbH?7#+eV6ZLkiH4U3hD|EP zvCCVEGM0Y=W84-H9`-=QH{4_JgWByVdRb2!Rkol-^x7F=CblvW)hagPLe`4ovBq|! zsGz_l8?~dSHec+vp@gCxm$c09fP;mnLsgzzEvL9VVxKw5`(vP>HN?uS8yg?{0ZIvc zf}W)PLo(kENuIl@xU`^o zlzZDimi@YnyUvWLEKD5X1_fIqnevQdQ=*)>1cwO&S^?$$vp_f#eH(qpMdpTtX3(b1 z`x`jawRvfnj`}R&3D@P_k(#(x5?o3H>liX@6Mq3zV;PT$(Ai@hSa_@*n#N5-jLG`aA!I`^~U$^=VCG&^kx>nfaA^#nUx3S zzrX*O3@w`)dlPjrep+B=C!DEnIlk~}4}MU*+u7SQ*t)~EDUN9aR2S7T=%0D!$-iII M9S1D_0f71W@=G6w!T<2J!}wcz;W7+b|G@_dbP+c59*8W)~y|8lXXq1ZbB8EsZTA6v>by#oF~6 zy|&L-#kXNHd8u*y_-UCXlJkbTXVC}E#hSo6=mZTTfz zF50<$zFPmd{pP4`BMM|vWLI{U_BB0kH`AUOD$Cz zxdS8ytZz&qlU%AM)uo17wvY(-u@cxbgC>xLeF4hZn3x%zsQhMq{|H9epZzEudj8|H zrcA_mbM`riKDk2b(A0X&Al0GcXng~Z8V670ce71VaB9q1C;3XWw^D4`E``Z2Msun+ z3^mmq!A^l8lz(&i3Z;{zh?;F0TN2`BVQ)n*6{%PWgbBBam>{6_Fy~2{Lt;c~wqxP~ z9mSrZOUCDQGf4nMsU#Dap21ejnt*7@midZS16Hz(k*K7>WdB_5?DicOZjywe+5oX# zvjQ2P#X)75&{Tz^DM=8KI(;y;J4iXJS-QkDgh}|T)PK(qy)hqc+~3Er1B=&zu_b*^ zG`uumFxj!oTZ%H4e*$CN@ev;OK*Sg8WAKAY^(cB-JsefGphfiB8DS>2G7;4(HsV6o zirXs@cBH7FuuC@THjrAlZJiAz6z#aAWqt=75S|XZ^4w}U#pMxa)Jfi7AqA}=R%YE` zeC!7(C4cY<{(z|6L#PSHkIQf9b#bDma^% No - 401: If authorization is invalid. - 403: If host address is forbidden. """ + if request.client.host not in models.session.allowed_origins: + LOGGER.info( + "Host: %s has been blocked since it is not added to allowed list", + request.client.host, + ) + LOGGER.info(models.session.allowed_origins) + raise exceptions.APIResponse( + status_code=HTTPStatus.FORBIDDEN.real, detail=HTTPStatus.FORBIDDEN.phrase + ) if apikey.credentials.startswith("\\"): auth = bytes(apikey.credentials, "utf-8").decode(encoding="unicode_escape") else: diff --git a/vaultapi/main.py b/vaultapi/main.py index 13bd4fd..e4cc8ab 100644 --- a/vaultapi/main.py +++ b/vaultapi/main.py @@ -21,6 +21,14 @@ def __init__(**kwargs) -> None: models.env = squire.load_env(**kwargs) models.session.fernet = Fernet(models.env.secret) models.database = models.Database(models.env.database) + default_allowed = ("0.0.0.0", "127.0.0.1", "localhost") + if models.env.host in default_allowed: + models.session.allowed_origins.update(default_allowed) + else: + models.session.allowed_origins.add(models.env.host) + for allowed in models.env.allowed_origins: + models.session.allowed_origins.add(allowed.host) + LOGGER.info("Allowed origins: %s", models.session.allowed_origins) def enable_cors() -> None: @@ -30,7 +38,7 @@ def enable_cors() -> None: "http://localhost.com", "https://localhost.com", ] - for website in models.env.endpoints: + for website in models.env.allowed_origins: origins.append(f"http://{website.host}") # noqa: HttpUrlsUsage origins.append(f"https://{website.host}") VaultAPI.add_middleware( diff --git a/vaultapi/models.py b/vaultapi/models.py index 8ffa3f6..e3b54f6 100644 --- a/vaultapi/models.py +++ b/vaultapi/models.py @@ -17,13 +17,9 @@ from pydantic_settings import BaseSettings -def complexity_checker(secret: str, simple: bool = False) -> None: +def complexity_checker(secret: str) -> None: """Verifies the strength of a secret. - Args: - secret: Value of the secret. - simple: Boolean flag to increase complexity. - See Also: A secret is considered strong if it at least has: @@ -36,19 +32,14 @@ def complexity_checker(secret: str, simple: bool = False) -> None: Raises: AssertionError: When at least 1 of the above conditions fail to match. """ - char_limit = 8 if simple else 32 - # calculates the length assert ( - len(secret) >= char_limit - ), f"Minimum secret length is {char_limit}, received {len(secret)}" + len(secret) >= 32 + ), f"secret length must be at least 32, received {len(secret)}" # searches for digits assert re.search(r"\d", secret), "secret must include an integer" - if simple: - return - # searches for uppercase assert re.search( r"[A-Z]", secret @@ -61,7 +52,7 @@ def complexity_checker(secret: str, simple: bool = False) -> None: # searches for symbols assert re.search( - r"[ !#$%&'()*+,-./[\\\]^_`{|}~" + r'"]', secret + r"[ !@#$%^&*()_='+,-./[\\\]`{|}~" + r'"]', secret ), "secret must contain at least one special character" @@ -126,19 +117,28 @@ class EnvConfig(BaseSettings): apikey: str secret: str database: FilePath | NewPath | str = Field("secrets.db", pattern=".*.db$") - endpoints: HttpUrl | List[HttpUrl] = [] host: str = socket.gethostbyname("localhost") or "0.0.0.0" port: PositiveInt = 8080 workers: PositiveInt = 1 log_config: FilePath | Dict[str, Any] | None = None - allowed_origins: List[str] = [] - rate_limit: RateLimit | List[RateLimit] = [] - - @field_validator("endpoints", mode="after", check_fields=True) - def parse_endpoints( + allowed_origins: HttpUrl | List[HttpUrl] = [] + # This is a base rate limit configuration + rate_limit: RateLimit | List[RateLimit] = [ + { + "max_requests": 5, + "seconds": 2, + }, # Burst limit: Prevents excessive load on the server + { + "max_requests": 10, + "seconds": 30, + }, # Sustained limit: Prevents too many trial and errors + ] + + @field_validator("allowed_origins", mode="after", check_fields=True) + def parse_allowed_origins( cls, value: HttpUrl | List[HttpUrl] # noqa: PyMethodParameters ) -> List[HttpUrl]: - """Validate endpoints to enable CORS policy.""" + """Validate allowed origins to enable CORS policy.""" if isinstance(value, list): return value return [value] @@ -148,22 +148,20 @@ def parse_apikey(cls, value: str | None) -> str | None: # noqa: PyMethodParamet """Parse API key to validate complexity.""" if value: try: - complexity_checker(value, True) + complexity_checker(value) except AssertionError as error: raise ValueError(error.__str__()) return value @field_validator("secret", mode="after") - def parse_api_secret( - cls, value: str | None # noqa: PyMethodParameters - ) -> str | None: - """Parse API secret to validate complexity.""" - if value: - try: - complexity_checker(value) - except AssertionError as error: - raise ValueError(error.__str__()) - return value + def parse_api_secret(cls, value: str) -> str: # noqa: PyMethodParameters + """Parse API secret to Fernet compatible.""" + try: + Fernet(value) + except ValueError as error: + exc = f"{error}\n\tConsider using 'vaultapi keygen' command to generate a valid secret." + raise ValueError(exc) + return value @classmethod def from_env_file(cls, env_file: pathlib.Path) -> "EnvConfig": diff --git a/vaultapi/rate_limit.py b/vaultapi/rate_limit.py index d7355a7..e3b68dd 100644 --- a/vaultapi/rate_limit.py +++ b/vaultapi/rate_limit.py @@ -1,14 +1,23 @@ import math import time +from collections import defaultdict from http import HTTPStatus +from threading import Lock from fastapi import HTTPException, Request from . import models +def _get_identifier(request: Request) -> str: + """Generate a unique identifier for the request.""" + if forwarded := request.headers.get("x-forwarded-for"): + return f"{forwarded.split(',')[0]}:{request.url.path}" + return f"{request.client.host}:{request.url.path}" + + class RateLimiter: - """Object that implements the ``RateLimiter`` functionality. + """Rate limiter for incoming requests. >>> RateLimiter @@ -27,13 +36,8 @@ def __init__(self, rps: models.RateLimit): """ self.max_requests = rps.max_requests self.seconds = rps.seconds - self.start_time = time.time() - self.exception = HTTPException( - status_code=HTTPStatus.TOO_MANY_REQUESTS.value, - detail=HTTPStatus.TOO_MANY_REQUESTS.phrase, - # reset headers, which will invalidate auth token - headers={"Retry-After": str(math.ceil(self.seconds))}, - ) + self.locks = defaultdict(Lock) # For thread-safe access + self.requests = defaultdict(list) def init(self, request: Request) -> None: """Checks if the number of calls exceeds the rate limit for the given identifier. @@ -44,23 +48,21 @@ def init(self, request: Request) -> None: Raises: 429: Too many requests. """ - if forwarded := request.headers.get("x-forwarded-for"): - identifier = forwarded.split(",")[0] - else: - identifier = request.client.host - identifier += ":" + request.url.path - + identifier = _get_identifier(request) current_time = time.time() - # Reset if the time window has passed - if current_time - self.start_time > self.seconds: - models.session.rps[identifier] = 1 - self.start_time = current_time - - if models.session.rps.get(identifier): - if models.session.rps[identifier] >= self.max_requests: - raise self.exception - else: - models.session.rps[identifier] += 1 - else: - models.session.rps[identifier] = 1 + with self.locks[identifier]: + # Clean up expired timestamps + self.requests[identifier] = [ + timestamp + for timestamp in self.requests[identifier] + if current_time - timestamp < self.seconds + ] + + if len(self.requests[identifier]) >= self.max_requests: + raise HTTPException( + status_code=HTTPStatus.TOO_MANY_REQUESTS.value, + detail=HTTPStatus.TOO_MANY_REQUESTS.phrase, + headers={"Retry-After": str(math.ceil(self.seconds))}, + ) + self.requests[identifier].append(current_time)