diff --git a/pydantic_settings_aws/__init__.py b/pydantic_settings_aws/__init__.py index 61ad266..48af907 100644 --- a/pydantic_settings_aws/__init__.py +++ b/pydantic_settings_aws/__init__.py @@ -1,6 +1,6 @@ -from .settings import SecretsManagerBaseSettings +from .settings import ParameterStoreBaseSettings, SecretsManagerBaseSettings from .version import VERSION -__all__ = [ "SecretsManagerBaseSettings" ] +__all__ = ["ParameterStoreBaseSettings", "SecretsManagerBaseSettings"] __version__ = VERSION diff --git a/pydantic_settings_aws/aws.py b/pydantic_settings_aws/aws.py index be94d4e..85ef901 100644 --- a/pydantic_settings_aws/aws.py +++ b/pydantic_settings_aws/aws.py @@ -1,9 +1,10 @@ import json -from typing import Any, Optional +from typing import Any, AnyStr, Dict, Optional, Union import boto3 from mypy_boto3_secretsmanager import SecretsManagerClient from mypy_boto3_secretsmanager.type_defs import GetSecretValueResponseTypeDef +from mypy_boto3_ssm import SSMClient from pydantic import ValidationError from pydantic_settings import BaseSettings @@ -11,8 +12,42 @@ from .models import AwsSecretsArgs, AwsSession +def get_ssm_content( + settings: type[BaseSettings], + field_name: str, + ssm_info: Optional[Union[Dict[Any, AnyStr], AnyStr]] = None +) -> Optional[str]: + client = None + ssm_name = field_name + + if isinstance(ssm_info, str): + logger.debug("Parameter name specified as a str") + ssm_name = ssm_info + + elif isinstance(ssm_info, dict): + logger.debug("Parameter specified as a dict") + ssm_name = str(ssm_info["ssm"]) + + logger.debug("Checking for a especific boto3 client for the Parameter") + client = ssm_info.get("ssm_client", None) + + else: + logger.debug("Will try to find a parameter with the parameter name") + + if not client: + logger.debug("Boto3 client not specified in metadata") + client = _get_ssm_boto3_client(settings) # type: ignore + + logger.debug(f"Getting parameter {ssm_name} value with boto3 client") + ssm_response: dict[str, Any] = client.get_parameter( # type: ignore + Name=ssm_name, WithDecryption=True + ) + + return ssm_response.get("Parameter", {}).get("Value", None) + + def get_secrets_content(settings: type[BaseSettings]) -> dict[str, Any]: - client: SecretsManagerClient = _get_boto3_client(settings) + client: SecretsManagerClient = _get_secrets_boto3_client(settings) secrets_args: AwsSecretsArgs = _get_secrets_args(settings) logger.debug("Getting secrets manager value with boto3 client") @@ -37,7 +72,9 @@ def get_secrets_content(settings: type[BaseSettings]) -> dict[str, Any]: raise json_err -def _get_boto3_client(settings: type[BaseSettings]) -> SecretsManagerClient: +def _get_secrets_boto3_client( + settings: type[BaseSettings], +) -> SecretsManagerClient: logger.debug("Getting secrets manager content.") client: SecretsManagerClient | None = settings.model_config.get( # type: ignore "secrets_client", None @@ -122,3 +159,45 @@ def _get_secrets_content( raise err return secrets_content + + +def _get_ssm_boto3_client(settings: type[BaseSettings]) -> SSMClient: + logger.debug("Getting secrets manager content.") + client: SSMClient | None = settings.model_config.get( # type: ignore + "ssm_client", None + ) + + if client: + return client + + logger.debug( + "No ssm boto3 client was informed. Will try to create a new one" + ) + return _create_ssm_client(settings) + + +def _create_ssm_client( + settings: type[BaseSettings], +) -> SSMClient: + """Create a boto3 client for parameter store. + + Neither `boto3` nor `pydantic` exceptions will be handled. + + Args: + settings (BaseSettings): Settings from `pydantic_settings` + + Returns: + SSMClient: A parameter ssm boto3 client. + """ + logger.debug("Extracting settings prefixed with aws_") + args: dict[str, Any] = { + k: v for k, v in settings.model_config.items() if k.startswith("aws_") + } + + session_args = AwsSession(**args) + + session: boto3.Session = boto3.Session( + **session_args.model_dump(by_alias=True, exclude_none=True) + ) + + return session.client("ssm") diff --git a/pydantic_settings_aws/settings.py b/pydantic_settings_aws/settings.py index f21d306..3c063e4 100644 --- a/pydantic_settings_aws/settings.py +++ b/pydantic_settings_aws/settings.py @@ -3,11 +3,29 @@ PydanticBaseSettingsSource, ) -from .sources import SecretsManagerSettingsSource +from .sources import ParameterStoreSettingsSource, SecretsManagerSettingsSource -class SecretsManagerBaseSettings(BaseSettings): +class ParameterStoreBaseSettings(BaseSettings): + @classmethod + def settings_customise_sources( + cls, + settings_cls: type[BaseSettings], + init_settings: PydanticBaseSettingsSource, + env_settings: PydanticBaseSettingsSource, + dotenv_settings: PydanticBaseSettingsSource, + file_secret_settings: PydanticBaseSettingsSource, + ) -> tuple[PydanticBaseSettingsSource, ...]: + return ( + init_settings, + ParameterStoreSettingsSource(settings_cls), + env_settings, + dotenv_settings, + file_secret_settings, + ) + +class SecretsManagerBaseSettings(BaseSettings): @classmethod def settings_customise_sources( cls, diff --git a/pydantic_settings_aws/sources.py b/pydantic_settings_aws/sources.py index 78edff0..fe340da 100644 --- a/pydantic_settings_aws/sources.py +++ b/pydantic_settings_aws/sources.py @@ -6,7 +6,46 @@ PydanticBaseSettingsSource, ) -from pydantic_settings_aws import aws +from pydantic_settings_aws import aws, utils + + +class ParameterStoreSettingsSource(PydanticBaseSettingsSource): + """Source class for loading settings from AWS Parameter Store. + """ + def __init__(self, settings_cls: type[BaseSettings]): + super().__init__(settings_cls) + + def get_field_value( + self, field: FieldInfo, field_name: str + ) -> tuple[Any, str, bool]: + ssm_info = utils.get_ssm_name_from_annotated_field(field.metadata) + field_value = aws.get_ssm_content(self.settings_cls, field_name, ssm_info) + + return field_value, field_name, False + + def prepare_field_value( + self, + field_name: str, + field: FieldInfo, + value: Any, + value_is_complex: bool, + ) -> Any: + return value + + def __call__(self) -> dict[str, Any]: + d: dict[str, Any] = {} + + for field_name, field in self.settings_cls.model_fields.items(): + field_value, field_key, value_is_complex = self.get_field_value( + field, field_name + ) + field_value = self.prepare_field_value( + field_name, field, field_value, value_is_complex + ) + if field_value is not None: + d[field_key] = field_value + + return d class SecretsManagerSettingsSource(PydanticBaseSettingsSource): diff --git a/pydantic_settings_aws/utils.py b/pydantic_settings_aws/utils.py new file mode 100644 index 0000000..e046dab --- /dev/null +++ b/pydantic_settings_aws/utils.py @@ -0,0 +1,22 @@ +from typing import Any, List, Optional + + +def get_ssm_name_from_annotated_field(metadata: List[Any]) -> Optional[str]: + ssm_metadata = list( + filter(_get_ssm_info_from_metadata, metadata) + ) + + if ssm_metadata: + return ssm_metadata[0] + + return None + + +def _get_ssm_info_from_metadata(metadata: Any) -> Optional[Any]: + if isinstance(metadata, str): + return metadata + + if isinstance(metadata, dict) and "ssm" in metadata.keys(): + return metadata + + return None diff --git a/pydantic_settings_aws/version.py b/pydantic_settings_aws/version.py index 901e511..bc4ffb3 100644 --- a/pydantic_settings_aws/version.py +++ b/pydantic_settings_aws/version.py @@ -1 +1 @@ -VERSION = "0.0.1" +VERSION = "0.0.2" diff --git a/pyproject.toml b/pyproject.toml index e017700..3725581 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ dependencies = [ 'pydantic-settings>=2.0.2', 'boto3>=1.27.0', 'boto3-stubs[secretsmanager]>=1.27.0', + 'boto3-stubs[ssm]>=1.27.0' ] dynamic = ['version'] @@ -72,7 +73,7 @@ source = ['pydantic_settings_aws/'] [tool.ruff] line-length = 80 -target-version = 'py39' +target-version = 'py38' [tool.ruff.lint] extend-select = ['Q', 'RUF100', 'C90', 'UP', 'I'] diff --git a/tests/aws_mocks.py b/tests/aws_mocks.py index 0cc7cdf..c88dd0a 100644 --- a/tests/aws_mocks.py +++ b/tests/aws_mocks.py @@ -4,20 +4,29 @@ TARGET_SESSION = "pydantic_settings_aws.aws.boto3.Session" -TARGET_BOTO3_CLIENT = "pydantic_settings_aws.aws._get_boto3_client" +TARGET_SECRETS_BOTO3_CLIENT = "pydantic_settings_aws.aws._get_secrets_boto3_client" + +TARGET_SSM_BOTO3_CLIENT = "pydantic_settings_aws.aws._get_ssm_boto3_client" TARGET_SECRETS_CLIENT = "pydantic_settings_aws.aws._create_secrets_client" +TARGET_SSM_CLIENT = "pydantic_settings_aws.aws._create_ssm_client" + TARGET_SECRET_CONTENT = "pydantic_settings_aws.aws._get_secrets_content" def mock_secrets_content_invalid_json(*args): return ClientMock(secret_string="invalid-json") + def mock_secrets_content_empty(*args): return ClientMock(secret_string=None) +def mock_ssm(*args): + return ClientMock(ssm_value="value") + + def mock_create_client(*args): return object() diff --git a/tests/aws_test.py b/tests/aws_test.py index bd908e0..39aabb0 100644 --- a/tests/aws_test.py +++ b/tests/aws_test.py @@ -7,18 +7,78 @@ from pydantic_settings_aws import aws from .aws_mocks import ( - TARGET_BOTO3_CLIENT, TARGET_SECRET_CONTENT, + TARGET_SECRETS_BOTO3_CLIENT, TARGET_SECRETS_CLIENT, TARGET_SESSION, + TARGET_SSM_CLIENT, BaseSettingsMock, mock_create_client, mock_secrets_content_empty, mock_secrets_content_invalid_json, + mock_ssm, ) from .boto3_mocks import SessionMock +@mock.patch(TARGET_SSM_CLIENT, mock_ssm) +def test_get_ssm_content_must_return_parameter_content_if_annotated_with_parameter_name(*args): + settings = BaseSettingsMock() + settings.model_config = {"aws_region": "region", "aws_profile": "profile"} + parameter_value = aws.get_ssm_content(settings, "field", "my/parameter/name") + + assert parameter_value is not None + assert isinstance(parameter_value, str) + + +@mock.patch(TARGET_SSM_CLIENT, mock_ssm) +def test_get_ssm_content_must_return_parameter_content_if_annotated_with_dict_args(*args): + settings = BaseSettingsMock() + settings.model_config = {"aws_region": "region", "aws_profile": "profile"} + parameter_value = aws.get_ssm_content(settings, "field", {"ssm": "my/parameter/name"}) + + assert parameter_value is not None + assert isinstance(parameter_value, str) + + +@mock.patch(TARGET_SSM_CLIENT, mock_ssm) +def test_get_ssm_content_must_use_client_if_present_in_metadata(*args): + settings = BaseSettingsMock() + settings.model_config = {"aws_region": "region", "aws_profile": "profile"} + parameter_value = aws.get_ssm_content(settings, "field", {"ssm": "my/parameter/name", "ssm_client": mock_ssm()}) + + assert parameter_value is not None + assert isinstance(parameter_value, str) + + +@mock.patch(TARGET_SSM_CLIENT, mock_ssm) +def test_get_ssm_content_must_use_field_name_if_ssm_name_not_in_metadata(*args): + settings = BaseSettingsMock() + settings.model_config = {"aws_region": "region", "aws_profile": "profile"} + parameter_value = aws.get_ssm_content(settings, "field", None) + + assert parameter_value is not None + assert isinstance(parameter_value, str) + + +@mock.patch(TARGET_SESSION, SessionMock) +def test_create_ssm_client(*args): + settings = BaseSettingsMock() + settings.model_config = {"aws_region": "region", "aws_profile": "profile"} + client = aws._create_ssm_client(settings) + + assert client is not None + + +@mock.patch(TARGET_SSM_CLIENT, mock_create_client) +def test_get_ssm_boto3_client_must_create_a_client_if_its_not_given(*args): + settings = BaseSettingsMock() + settings.model_config = {} + client = aws._get_ssm_boto3_client(settings) + + assert client is not None + + @mock.patch(TARGET_SESSION, SessionMock) def test_create_secrets_client(*args): settings = BaseSettingsMock() @@ -29,15 +89,15 @@ def test_create_secrets_client(*args): @mock.patch(TARGET_SECRETS_CLIENT, mock_create_client) -def test_get_boto3_client_must_create_a_client_if_its_not_given(*args): +def test_get_secrets_boto3_client_must_create_a_client_if_its_not_given(*args): settings = BaseSettingsMock() settings.model_config = {} - client = aws._get_boto3_client(settings) + client = aws._get_secrets_boto3_client(settings) assert client is not None -@mock.patch(TARGET_BOTO3_CLIENT, mock_secrets_content_empty) +@mock.patch(TARGET_SECRETS_BOTO3_CLIENT, mock_secrets_content_empty) @mock.patch(TARGET_SECRET_CONTENT, lambda *args: None) def test_get_secrets_content_must_raise_value_error_if_secrets_content_is_none( *args, @@ -53,7 +113,7 @@ def test_get_secrets_content_must_raise_value_error_if_secrets_content_is_none( aws.get_secrets_content(settings) -@mock.patch(TARGET_BOTO3_CLIENT, mock_secrets_content_invalid_json) +@mock.patch(TARGET_SECRETS_BOTO3_CLIENT, mock_secrets_content_invalid_json) def test_should_not_obfuscate_json_error_in_case_of_invalid_secrets(*args): settings = BaseSettingsMock() settings.model_config = { diff --git a/tests/boto3_mocks.py b/tests/boto3_mocks.py index 7612a71..935d623 100644 --- a/tests/boto3_mocks.py +++ b/tests/boto3_mocks.py @@ -13,10 +13,19 @@ class ClientMock: def __init__( self, secret_string: str = None, - secret_bytes: bytes = None + secret_bytes: bytes = None, + ssm_value: str = None ) -> None: self.secret_string = secret_string self.secret_bytes = secret_bytes + self.ssm_value = ssm_value + + def get_parameter(self, Name=None, WithDecryption=None): + return { + "Parameter": { + "Value": self.ssm_value + } + } def get_secret_value( self, SecretId=None, VersionId=None, VersionStage=None diff --git a/tests/settings_mocks.py b/tests/settings_mocks.py index 4ac7862..ea5038c 100644 --- a/tests/settings_mocks.py +++ b/tests/settings_mocks.py @@ -1,10 +1,13 @@ import json -from typing import Optional +from typing import Annotated, Optional from pydantic import BaseModel from pydantic_settings import SettingsConfigDict -from pydantic_settings_aws import SecretsManagerBaseSettings +from pydantic_settings_aws import ( + ParameterStoreBaseSettings, + SecretsManagerBaseSettings, +) from .boto3_mocks import ClientMock @@ -54,3 +57,24 @@ class SecretsWithNestedContent(SecretsManagerBaseSettings): username: str password: str nested: NestedContent + + +class ParameterSettings(ParameterStoreBaseSettings): + my_ssm: Annotated[str, {"ssm": "my/parameter", "ssm_client": ClientMock(ssm_value="value")}] + + +class ParameterWithTwoSSMClientSettings(ParameterStoreBaseSettings): + model_config = SettingsConfigDict( + ssm_client=ClientMock(ssm_value="value") + ) + + my_ssm: Annotated[str, {"ssm": "my/parameter", "ssm_client": ClientMock(ssm_value="value")}] + my_ssm_2: Annotated[str, "my/ssm/2/parameter"] + + +class ParameterWithOptionalValueSettings(ParameterStoreBaseSettings): + model_config = SettingsConfigDict( + ssm_client=ClientMock() + ) + + my_ssm: Annotated[Optional[str], "my/ssm/2/parameter"] = None diff --git a/tests/settings_test.py b/tests/settings_test.py index 755141c..21eb7f3 100644 --- a/tests/settings_test.py +++ b/tests/settings_test.py @@ -1,5 +1,8 @@ from .settings_mocks import ( MySecretsWithClientConfig, + ParameterSettings, + ParameterWithOptionalValueSettings, + ParameterWithTwoSSMClientSettings, SecretsWithNestedContent, ) @@ -18,3 +21,29 @@ def test_secrets_settings_with_nested_secrets_content(): assert my_config.username == "myusername" assert my_config.nested is not None assert len(my_config.nested.roles) > 0 + + +def test_ssm_with_annotated_str(): + my_config = ParameterSettings() + + assert my_config is not None + assert my_config.my_ssm is not None + assert isinstance(my_config.my_ssm, str) + + +def test_ssm_with_and_without_ssm_client(): + my_config = ParameterWithTwoSSMClientSettings() + + assert my_config is not None + assert my_config.my_ssm is not None + assert isinstance(my_config.my_ssm, str) + + assert my_config.my_ssm_2 is not None + assert isinstance(my_config.my_ssm_2, str) + + +def test_ssm_with_none_in_optional_values(): + my_config = ParameterWithOptionalValueSettings() + + assert my_config is not None + assert my_config.my_ssm is None diff --git a/tests/utils_test.py b/tests/utils_test.py new file mode 100644 index 0000000..f9a5ea5 --- /dev/null +++ b/tests/utils_test.py @@ -0,0 +1,20 @@ +from pydantic_settings_aws import utils + + +def test_get_annotated_ssm(*args): + metadata = [{"ssm": "my/ssm"}] + a = utils.get_ssm_name_from_annotated_field(metadata) + + assert a is not None + assert a["ssm"] == "my/ssm" + + metadata = ["my/ssm"] + a = utils.get_ssm_name_from_annotated_field(metadata) + + assert a is not None + assert a == "my/ssm" + + metadata = [str] + a = utils.get_ssm_name_from_annotated_field(metadata) + + assert a is None