From c10f673fa6da55499de1fa63c4416ccb8d8bbce2 Mon Sep 17 00:00:00 2001 From: Vignesh Rao Date: Tue, 17 Sep 2024 07:29:12 -0500 Subject: [PATCH] Retrieve multiple secrets at a time or an entire table --- vaultapi/database.py | 18 ++++++ vaultapi/routers.py | 136 +++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 148 insertions(+), 6 deletions(-) diff --git a/vaultapi/database.py b/vaultapi/database.py index e2d7eaf..84f0b67 100644 --- a/vaultapi/database.py +++ b/vaultapi/database.py @@ -1,3 +1,5 @@ +from typing import List, Tuple + from . import models @@ -21,6 +23,22 @@ def get_secret(key: str, table_name: str) -> str | None: return state[0] +def get_table(table_name: str) -> List[Tuple[str, str]]: + """Function to retrieve all key-value pairs from a particular table in the database. + + Args: + table_name: Name of the table where the secrets are stored. + + Returns: + str: + Returns the secret value. + """ + with models.database.connection: + cursor = models.database.connection.cursor() + state = cursor.execute(f'SELECT * FROM "{table_name}"').fetchall() + return state + + def put_secret(key: str, value: str, table_name: str) -> None: """Function to add secret to the database. diff --git a/vaultapi/routers.py b/vaultapi/routers.py index 08d4992..6578bc8 100644 --- a/vaultapi/routers.py +++ b/vaultapi/routers.py @@ -14,8 +14,8 @@ security = HTTPBearer() -async def retrieve_existing(key: str, table_name: str) -> str | None: - """Retrieve existing secret from database. +async def retrieve_secret(key: str, table_name: str) -> str | None: + """Retrieve an existing secret from a table in the database. Args: key: Name of the secret to retrieve. @@ -34,13 +34,40 @@ async def retrieve_existing(key: str, table_name: str) -> str | None: ) +async def retrieve_secrets(table_name: str, keys: List[str] = None) -> Dict[str, str]: + """Retrieve multiple secrets from a table or retrieve the table as a whole. + + Args: + table_name: Name of the table where the secret is stored. + keys: List of keys for which the values have to be retrieved. + + Returns: + Dict[str, str]: + Returns the key-value pairs for secret key and it's value. + """ + if keys: + values = {} + for key in keys: + if value := await retrieve_secret(key, table_name): + values[key] = value + return values + else: + try: + return dict(database.get_table(table_name)) + except sqlite3.OperationalError as error: + LOGGER.error(error) + raise exceptions.APIResponse( + status_code=HTTPStatus.BAD_REQUEST.real, detail=error.args[0] + ) + + async def get_secret( request: Request, key: str, table_name: str = "default", apikey: HTTPAuthorizationCredentials = Depends(security), ): - """**API function to retrieve secrets.** + """**API function to retrieve a secret.** **Args:** @@ -55,7 +82,7 @@ async def get_secret( Raises the HTTPStatus object with a status code and detail as response. """ await auth.validate(request, apikey) - if value := await retrieve_existing(key, table_name): + if value := await retrieve_secret(key, table_name): LOGGER.info("Secret value for '%s' was retrieved", key) decrypted = models.session.fernet.decrypt(value).decode(encoding="UTF-8") raise exceptions.APIResponse(status_code=HTTPStatus.OK.real, detail=decrypted) @@ -65,6 +92,91 @@ async def get_secret( ) +async def get_secrets( + request: Request, + keys: List[str], + table_name: str = "default", + apikey: HTTPAuthorizationCredentials = Depends(security), +): + """**API function to retrieve multiple secrets at a time.** + + **Args:** + + request: Reference to the FastAPI request object. + key: List of secret names to be retrieved. + table_name: Name of the table where the secrets are stored. + apikey: API Key to authenticate the request. + + **Raises:** + + APIResponse: + Raises the HTTPStatus object with a status code and detail as response. + """ + await auth.validate(request, apikey) + keys_ct = len(keys) + try: + assert keys_ct >= 1, f"Expected at least one key, received {keys_ct}" + except AssertionError as error: + LOGGER.error(error) + raise exceptions.APIResponse( + status_code=HTTPStatus.BAD_REQUEST.real, detail=error.args[0] + ) + if values := await retrieve_secrets(table_name, keys): + values_ct = len(values) + try: + assert ( + values_ct == keys_ct + ), f"Number of keys [{keys_ct}] requested didn't match the number of values [{values_ct}] retrieved." + LOGGER.info("Secret value for %d (%s) were retrieved", keys_ct, keys) + code = HTTPStatus.OK.real + except AssertionError as error: + LOGGER.warning(error) + code = HTTPStatus.PARTIAL_CONTENT.real + decrypted = { + key: models.session.fernet.decrypt(value).decode(encoding="UTF-8") + for key, value in values.items() + } + raise exceptions.APIResponse(status_code=code, detail=decrypted) + if keys_ct == 1: + LOGGER.info("Secret value for '%s' NOT found in the datastore", keys[0]) + else: + LOGGER.info( + "Secret values for %d keys (%s) were NOT found in the datastore", + keys_ct, + keys, + ) + raise exceptions.APIResponse( + status_code=HTTPStatus.NOT_FOUND.real, detail=HTTPStatus.NOT_FOUND.phrase + ) + + +async def get_table( + request: Request, + table_name: str = "default", + apikey: HTTPAuthorizationCredentials = Depends(security), +): + """**API function to retrieve ALL the key-value pairs stored in a particular table.** + + **Args:** + + request: Reference to the FastAPI request object. + table_name: Name of the table where the secrets are stored. + apikey: API Key to authenticate the request. + + **Raises:** + + APIResponse: + Raises the HTTPStatus object with a status code and detail as response. + """ + await auth.validate(request, apikey) + table_content = await retrieve_secrets(table_name) + decrypted = { + key: models.session.fernet.decrypt(value).decode(encoding="UTF-8") + for key, value in table_content.items() + } + raise exceptions.APIResponse(status_code=HTTPStatus.OK.real, detail=decrypted) + + async def put_secret( request: Request, data: payload.PutSecret, @@ -84,7 +196,7 @@ async def put_secret( Raises the HTTPStatus object with a status code and detail as response. """ await auth.validate(request, apikey) - if await retrieve_existing(data.key, data.table_name): + if await retrieve_secret(data.key, data.table_name): LOGGER.info("Secret value for '%s' will be overridden", data.key) else: LOGGER.info( @@ -118,7 +230,7 @@ async def delete_secret( Raises the HTTPStatus object with a status code and detail as response. """ await auth.validate(request, apikey) - if await retrieve_existing(data.key, data.table_name): + if await retrieve_secret(data.key, data.table_name): LOGGER.info("Secret value for '%s' will be removed", data.key) else: LOGGER.warning("Secret value for '%s' NOT found", data.key) @@ -204,6 +316,18 @@ def get_all_routes() -> List[APIRoute]: methods=["GET"], dependencies=dependencies, ), + APIRoute( + path="/get-secrets", + endpoint=get_secrets, + methods=["POST"], + dependencies=dependencies, + ), + APIRoute( + path="/get-table", + endpoint=get_table, + methods=["GET"], + dependencies=dependencies, + ), APIRoute( path="/put-secret", endpoint=put_secret,