Skip to content

Commit

Permalink
Enable CORS middleware settings
Browse files Browse the repository at this point in the history
Freeze requirements
Add a util script to encode secret
  • Loading branch information
dormant-user committed Sep 17, 2024
1 parent c10f673 commit 0cdcd3b
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 40 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# env-files
API to share env files
# VaultAPI
API to store and retrieve secrets
29 changes: 22 additions & 7 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,22 @@
cryptography
fastapi
pydantic
pydantic-settings
PyYaml
requests
uvicorn
annotated-types==0.7.0
anyio==4.4.0
certifi==2024.8.30
cffi==1.17.1
charset-normalizer==3.3.2
click==8.1.7
cryptography==43.0.1
fastapi==0.114.2
h11==0.14.0
idna==3.10
pycparser==2.22
pydantic==2.9.1
pydantic-settings==2.5.2
pydantic_core==2.23.3
python-dotenv==1.0.1
PyYAML==6.0.2
requests==2.32.3
sniffio==1.3.1
starlette==0.38.5
typing_extensions==4.12.2
urllib3==2.2.3
uvicorn==0.30.6
1 change: 0 additions & 1 deletion secrets/.sample

This file was deleted.

29 changes: 27 additions & 2 deletions vaultapi/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import uvicorn
from cryptography.fernet import Fernet
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from . import models, routers, squire, version
from . import models, routes, squire, version

LOGGER = logging.getLogger("uvicorn.default")
VaultAPI = FastAPI(
Expand All @@ -15,6 +16,29 @@
)


def enable_cors() -> None:
"""Enables CORS policy."""
LOGGER.info("Setting CORS policy")
origins = [
"http://localhost.com",
"https://localhost.com",
]
for website in models.env.endpoints:
origins.extend([f"http://{website.host}", f"https://{website.host}"]) # noqa: HttpUrlsUsage
VaultAPI.add_middleware(
CORSMiddleware, # noqa: PyTypeChecker
allow_origins=origins,
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "DELETE"],
allow_headers=[
# Default headers
"host",
"user-agent",
"authorization",
],
)


def start(**kwargs) -> None:
"""Starter function for the API, which uses uvicorn server as trigger.
Expand All @@ -34,7 +58,8 @@ def start(**kwargs) -> None:
models.database = models.Database(models.env.database)
models.database.create_table("default", ["key", "value"])
module_name = pathlib.Path(__file__)
VaultAPI.routes.extend(routers.get_all_routes())
enable_cors()
VaultAPI.routes.extend(routes.get_all_routes())
kwargs = dict(
host=models.env.host,
port=models.env.port,
Expand Down
36 changes: 13 additions & 23 deletions vaultapi/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Any, Dict, List, Set, Tuple

from cryptography.fernet import Fernet
from pydantic import BaseModel, Field, FilePath, NewPath, PositiveInt, field_validator
from pydantic import BaseModel, Field, FilePath, NewPath, PositiveInt, field_validator, HttpUrl
from pydantic_settings import BaseSettings


Expand Down Expand Up @@ -132,44 +132,34 @@ class EnvConfig(BaseSettings):
apikey: str
secret: str
database: FilePath | NewPath | str = Field("secrets.db", pattern=".*.db$")
endpoints: HttpUrl | List[HttpUrl] = []
host: str = socket.gethostbyname("localhost") or "0.0.0.0"
port: PositiveInt = 8080
workers: PositiveInt = 1
log_config: FilePath | Dict[str, Any] | None = None
allowed_origins: List[str] = []
rate_limit: RateLimit | List[RateLimit] = []

# noinspection PyMethodParameters
@field_validator("apikey", mode="after")
def parse_apikey(cls, value: str | None) -> str | None:
"""Parse API key to validate complexity.
Args:
value: Takes the user input as an argument.
@field_validator("endpoints", mode="after", check_fields=True)
def parse_endpoints(cls, value: HttpUrl | List[HttpUrl]) -> List[HttpUrl]: # noqa: PyMethodParameters
"""Validate endpoints to enable CORS policy."""
if isinstance(value, list):
return value
return [value]

Returns:
str:
Returns the parsed value.
"""
@field_validator("apikey", mode="after")
def parse_apikey(cls, value: str | None) -> str | None: # noqa: PyMethodParameters
"""Parse API key to validate complexity."""
if value:
try:
complexity_checker(value, True)
except AssertionError as error:
raise ValueError(error.__str__())
return value

# noinspection PyMethodParameters
@field_validator("secret", mode="after")
def parse_api_secret(cls, value: str | None) -> str | None:
"""Parse API secret to validate complexity.
Args:
value: Takes the user input as an argument.
Returns:
str:
Returns the parsed value.
"""
def parse_api_secret(cls, value: str | None) -> str | None: # noqa: PyMethodParameters
"""Parse API secret to validate complexity."""
if value:
try:
complexity_checker(value)
Expand Down
12 changes: 7 additions & 5 deletions vaultapi/routers.py → vaultapi/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ async def get_secret(

async def get_secrets(
request: Request,
keys: List[str],
keys: str,
table_name: str = "default",
apikey: HTTPAuthorizationCredentials = Depends(security),
):
Expand All @@ -103,7 +103,7 @@ async def get_secrets(
**Args:**
request: Reference to the FastAPI request object.
key: List of secret names to be retrieved.
key: Comma separated list of secret names to be retrieved.
table_name: Name of the table where the secrets are stored.
apikey: API Key to authenticate the request.
Expand All @@ -113,6 +113,8 @@ async def get_secrets(
Raises the HTTPStatus object with a status code and detail as response.
"""
await auth.validate(request, apikey)
# keys = [key.strip() for key in keys.split(",") if key.strip()]
keys = list(filter(None, map(str.strip, keys.split(","))))
keys_ct = len(keys)
try:
assert keys_ct >= 1, f"Expected at least one key, received {keys_ct}"
Expand Down Expand Up @@ -141,7 +143,7 @@ async def get_secrets(
LOGGER.info("Secret value for '%s' NOT found in the datastore", keys[0])
else:
LOGGER.info(
"Secret values for %d keys (%s) were NOT found in the datastore",
"Secret values for %d keys %s were NOT found in the datastore",
keys_ct,
keys,
)
Expand Down Expand Up @@ -319,7 +321,7 @@ def get_all_routes() -> List[APIRoute]:
APIRoute(
path="/get-secrets",
endpoint=get_secrets,
methods=["POST"],
methods=["GET"],
dependencies=dependencies,
),
APIRoute(
Expand All @@ -331,7 +333,7 @@ def get_all_routes() -> List[APIRoute]:
APIRoute(
path="/put-secret",
endpoint=put_secret,
methods=["POST"],
methods=["PUT"],
dependencies=dependencies,
),
APIRoute(
Expand Down
15 changes: 15 additions & 0 deletions vaultapi/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import base64
from typing import ByteString


def encode_secret(key: str) -> ByteString:
"""Encodes a key into URL safe string.
Args:
key: Key to be encoded.
Returns:
ByteString:
Returns an encoded URL safe string.
"""
return base64.urlsafe_b64encode(key.encode(encoding="UTF-8"))

0 comments on commit 0cdcd3b

Please sign in to comment.