Skip to content

Commit

Permalink
Merge pull request #8 from ceb10n/features/parameter-store
Browse files Browse the repository at this point in the history
Add Parameter Store Settings
  • Loading branch information
ceb10n authored Jul 20, 2024
2 parents d7a8d60 + 8d876e9 commit 2e308e8
Show file tree
Hide file tree
Showing 13 changed files with 329 additions and 19 deletions.
4 changes: 2 additions & 2 deletions pydantic_settings_aws/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .settings import SecretsManagerBaseSettings
from .settings import ParameterStoreBaseSettings, SecretsManagerBaseSettings
from .version import VERSION

__all__ = [ "SecretsManagerBaseSettings" ]
__all__ = ["ParameterStoreBaseSettings", "SecretsManagerBaseSettings"]

__version__ = VERSION
85 changes: 82 additions & 3 deletions pydantic_settings_aws/aws.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,53 @@
import json
from typing import Any, Optional
from typing import Any, AnyStr, Dict, Optional, Union

import boto3
from mypy_boto3_secretsmanager import SecretsManagerClient
from mypy_boto3_secretsmanager.type_defs import GetSecretValueResponseTypeDef
from mypy_boto3_ssm import SSMClient
from pydantic import ValidationError
from pydantic_settings import BaseSettings

from .logger import logger
from .models import AwsSecretsArgs, AwsSession


def get_ssm_content(
settings: type[BaseSettings],
field_name: str,
ssm_info: Optional[Union[Dict[Any, AnyStr], AnyStr]] = None
) -> Optional[str]:
client = None
ssm_name = field_name

if isinstance(ssm_info, str):
logger.debug("Parameter name specified as a str")
ssm_name = ssm_info

elif isinstance(ssm_info, dict):
logger.debug("Parameter specified as a dict")
ssm_name = str(ssm_info["ssm"])

logger.debug("Checking for a especific boto3 client for the Parameter")
client = ssm_info.get("ssm_client", None)

else:
logger.debug("Will try to find a parameter with the parameter name")

if not client:
logger.debug("Boto3 client not specified in metadata")
client = _get_ssm_boto3_client(settings) # type: ignore

logger.debug(f"Getting parameter {ssm_name} value with boto3 client")
ssm_response: dict[str, Any] = client.get_parameter( # type: ignore
Name=ssm_name, WithDecryption=True
)

return ssm_response.get("Parameter", {}).get("Value", None)


def get_secrets_content(settings: type[BaseSettings]) -> dict[str, Any]:
client: SecretsManagerClient = _get_boto3_client(settings)
client: SecretsManagerClient = _get_secrets_boto3_client(settings)
secrets_args: AwsSecretsArgs = _get_secrets_args(settings)

logger.debug("Getting secrets manager value with boto3 client")
Expand All @@ -37,7 +72,9 @@ def get_secrets_content(settings: type[BaseSettings]) -> dict[str, Any]:
raise json_err


def _get_boto3_client(settings: type[BaseSettings]) -> SecretsManagerClient:
def _get_secrets_boto3_client(
settings: type[BaseSettings],
) -> SecretsManagerClient:
logger.debug("Getting secrets manager content.")
client: SecretsManagerClient | None = settings.model_config.get( # type: ignore
"secrets_client", None
Expand Down Expand Up @@ -122,3 +159,45 @@ def _get_secrets_content(
raise err

return secrets_content


def _get_ssm_boto3_client(settings: type[BaseSettings]) -> SSMClient:
logger.debug("Getting secrets manager content.")
client: SSMClient | None = settings.model_config.get( # type: ignore
"ssm_client", None
)

if client:
return client

logger.debug(
"No ssm boto3 client was informed. Will try to create a new one"
)
return _create_ssm_client(settings)


def _create_ssm_client(
settings: type[BaseSettings],
) -> SSMClient:
"""Create a boto3 client for parameter store.
Neither `boto3` nor `pydantic` exceptions will be handled.
Args:
settings (BaseSettings): Settings from `pydantic_settings`
Returns:
SSMClient: A parameter ssm boto3 client.
"""
logger.debug("Extracting settings prefixed with aws_")
args: dict[str, Any] = {
k: v for k, v in settings.model_config.items() if k.startswith("aws_")
}

session_args = AwsSession(**args)

session: boto3.Session = boto3.Session(
**session_args.model_dump(by_alias=True, exclude_none=True)
)

return session.client("ssm")
22 changes: 20 additions & 2 deletions pydantic_settings_aws/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,29 @@
PydanticBaseSettingsSource,
)

from .sources import SecretsManagerSettingsSource
from .sources import ParameterStoreSettingsSource, SecretsManagerSettingsSource


class SecretsManagerBaseSettings(BaseSettings):
class ParameterStoreBaseSettings(BaseSettings):
@classmethod
def settings_customise_sources(
cls,
settings_cls: type[BaseSettings],
init_settings: PydanticBaseSettingsSource,
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource,
) -> tuple[PydanticBaseSettingsSource, ...]:
return (
init_settings,
ParameterStoreSettingsSource(settings_cls),
env_settings,
dotenv_settings,
file_secret_settings,
)


class SecretsManagerBaseSettings(BaseSettings):
@classmethod
def settings_customise_sources(
cls,
Expand Down
41 changes: 40 additions & 1 deletion pydantic_settings_aws/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,46 @@
PydanticBaseSettingsSource,
)

from pydantic_settings_aws import aws
from pydantic_settings_aws import aws, utils


class ParameterStoreSettingsSource(PydanticBaseSettingsSource):
"""Source class for loading settings from AWS Parameter Store.
"""
def __init__(self, settings_cls: type[BaseSettings]):
super().__init__(settings_cls)

def get_field_value(
self, field: FieldInfo, field_name: str
) -> tuple[Any, str, bool]:
ssm_info = utils.get_ssm_name_from_annotated_field(field.metadata)
field_value = aws.get_ssm_content(self.settings_cls, field_name, ssm_info)

return field_value, field_name, False

def prepare_field_value(
self,
field_name: str,
field: FieldInfo,
value: Any,
value_is_complex: bool,
) -> Any:
return value

def __call__(self) -> dict[str, Any]:
d: dict[str, Any] = {}

for field_name, field in self.settings_cls.model_fields.items():
field_value, field_key, value_is_complex = self.get_field_value(
field, field_name
)
field_value = self.prepare_field_value(
field_name, field, field_value, value_is_complex
)
if field_value is not None:
d[field_key] = field_value

return d


class SecretsManagerSettingsSource(PydanticBaseSettingsSource):
Expand Down
22 changes: 22 additions & 0 deletions pydantic_settings_aws/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import Any, List, Optional


def get_ssm_name_from_annotated_field(metadata: List[Any]) -> Optional[str]:
ssm_metadata = list(
filter(_get_ssm_info_from_metadata, metadata)
)

if ssm_metadata:
return ssm_metadata[0]

return None


def _get_ssm_info_from_metadata(metadata: Any) -> Optional[Any]:
if isinstance(metadata, str):
return metadata

if isinstance(metadata, dict) and "ssm" in metadata.keys():
return metadata

return None
2 changes: 1 addition & 1 deletion pydantic_settings_aws/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
VERSION = "0.0.1"
VERSION = "0.0.2"
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ dependencies = [
'pydantic-settings>=2.0.2',
'boto3>=1.27.0',
'boto3-stubs[secretsmanager]>=1.27.0',
'boto3-stubs[ssm]>=1.27.0'
]
dynamic = ['version']

Expand Down Expand Up @@ -68,7 +69,7 @@ source = ['pydantic_settings_aws/']

[tool.ruff]
line-length = 80
target-version = 'py39'
target-version = 'py38'

[tool.ruff.lint]
extend-select = ['Q', 'RUF100', 'C90', 'UP', 'I']
Expand Down
11 changes: 10 additions & 1 deletion tests/aws_mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,29 @@

TARGET_SESSION = "pydantic_settings_aws.aws.boto3.Session"

TARGET_BOTO3_CLIENT = "pydantic_settings_aws.aws._get_boto3_client"
TARGET_SECRETS_BOTO3_CLIENT = "pydantic_settings_aws.aws._get_secrets_boto3_client"

TARGET_SSM_BOTO3_CLIENT = "pydantic_settings_aws.aws._get_ssm_boto3_client"

TARGET_SECRETS_CLIENT = "pydantic_settings_aws.aws._create_secrets_client"

TARGET_SSM_CLIENT = "pydantic_settings_aws.aws._create_ssm_client"

TARGET_SECRET_CONTENT = "pydantic_settings_aws.aws._get_secrets_content"


def mock_secrets_content_invalid_json(*args):
return ClientMock(secret_string="invalid-json")


def mock_secrets_content_empty(*args):
return ClientMock(secret_string=None)


def mock_ssm(*args):
return ClientMock(ssm_value="value")


def mock_create_client(*args):
return object()

Expand Down
70 changes: 65 additions & 5 deletions tests/aws_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,78 @@
from pydantic_settings_aws import aws

from .aws_mocks import (
TARGET_BOTO3_CLIENT,
TARGET_SECRET_CONTENT,
TARGET_SECRETS_BOTO3_CLIENT,
TARGET_SECRETS_CLIENT,
TARGET_SESSION,
TARGET_SSM_CLIENT,
BaseSettingsMock,
mock_create_client,
mock_secrets_content_empty,
mock_secrets_content_invalid_json,
mock_ssm,
)
from .boto3_mocks import SessionMock


@mock.patch(TARGET_SSM_CLIENT, mock_ssm)
def test_get_ssm_content_must_return_parameter_content_if_annotated_with_parameter_name(*args):
settings = BaseSettingsMock()
settings.model_config = {"aws_region": "region", "aws_profile": "profile"}
parameter_value = aws.get_ssm_content(settings, "field", "my/parameter/name")

assert parameter_value is not None
assert isinstance(parameter_value, str)


@mock.patch(TARGET_SSM_CLIENT, mock_ssm)
def test_get_ssm_content_must_return_parameter_content_if_annotated_with_dict_args(*args):
settings = BaseSettingsMock()
settings.model_config = {"aws_region": "region", "aws_profile": "profile"}
parameter_value = aws.get_ssm_content(settings, "field", {"ssm": "my/parameter/name"})

assert parameter_value is not None
assert isinstance(parameter_value, str)


@mock.patch(TARGET_SSM_CLIENT, mock_ssm)
def test_get_ssm_content_must_use_client_if_present_in_metadata(*args):
settings = BaseSettingsMock()
settings.model_config = {"aws_region": "region", "aws_profile": "profile"}
parameter_value = aws.get_ssm_content(settings, "field", {"ssm": "my/parameter/name", "ssm_client": mock_ssm()})

assert parameter_value is not None
assert isinstance(parameter_value, str)


@mock.patch(TARGET_SSM_CLIENT, mock_ssm)
def test_get_ssm_content_must_use_field_name_if_ssm_name_not_in_metadata(*args):
settings = BaseSettingsMock()
settings.model_config = {"aws_region": "region", "aws_profile": "profile"}
parameter_value = aws.get_ssm_content(settings, "field", None)

assert parameter_value is not None
assert isinstance(parameter_value, str)


@mock.patch(TARGET_SESSION, SessionMock)
def test_create_ssm_client(*args):
settings = BaseSettingsMock()
settings.model_config = {"aws_region": "region", "aws_profile": "profile"}
client = aws._create_ssm_client(settings)

assert client is not None


@mock.patch(TARGET_SSM_CLIENT, mock_create_client)
def test_get_ssm_boto3_client_must_create_a_client_if_its_not_given(*args):
settings = BaseSettingsMock()
settings.model_config = {}
client = aws._get_ssm_boto3_client(settings)

assert client is not None


@mock.patch(TARGET_SESSION, SessionMock)
def test_create_secrets_client(*args):
settings = BaseSettingsMock()
Expand All @@ -29,15 +89,15 @@ def test_create_secrets_client(*args):


@mock.patch(TARGET_SECRETS_CLIENT, mock_create_client)
def test_get_boto3_client_must_create_a_client_if_its_not_given(*args):
def test_get_secrets_boto3_client_must_create_a_client_if_its_not_given(*args):
settings = BaseSettingsMock()
settings.model_config = {}
client = aws._get_boto3_client(settings)
client = aws._get_secrets_boto3_client(settings)

assert client is not None


@mock.patch(TARGET_BOTO3_CLIENT, mock_secrets_content_empty)
@mock.patch(TARGET_SECRETS_BOTO3_CLIENT, mock_secrets_content_empty)
@mock.patch(TARGET_SECRET_CONTENT, lambda *args: None)
def test_get_secrets_content_must_raise_value_error_if_secrets_content_is_none(
*args,
Expand All @@ -53,7 +113,7 @@ def test_get_secrets_content_must_raise_value_error_if_secrets_content_is_none(
aws.get_secrets_content(settings)


@mock.patch(TARGET_BOTO3_CLIENT, mock_secrets_content_invalid_json)
@mock.patch(TARGET_SECRETS_BOTO3_CLIENT, mock_secrets_content_invalid_json)
def test_should_not_obfuscate_json_error_in_case_of_invalid_secrets(*args):
settings = BaseSettingsMock()
settings.model_config = {
Expand Down
Loading

0 comments on commit 2e308e8

Please sign in to comment.