Skip to content

Commit

Permalink
Merge pull request #25 from ceb10n/features/organize-aws
Browse files Browse the repository at this point in the history
Add boto3 client cache
  • Loading branch information
ceb10n authored Jul 25, 2024
2 parents c8aca4e + ab4ef4a commit d89b088
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 11 deletions.
12 changes: 11 additions & 1 deletion pydantic_settings_aws/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

ClientParam = Literal["secrets_client", "ssm_client"]

_client_cache: Dict[str, Any] = {}


def get_ssm_content(
settings: Type[BaseSettings],
Expand Down Expand Up @@ -153,8 +155,16 @@ def _create_boto3_client(session_args: AwsSession, service: AWSService): # type
Returns:
boto3.client: An aws service boto3 client.
"""
cache_key = service + "_" + session_args.session_key()

if cache_key in _client_cache:
return _client_cache[cache_key]

session: boto3.Session = boto3.Session(
**session_args.model_dump(by_alias=True, exclude_none=True)
)

return session.client(service)
client = session.client(service)
_client_cache[cache_key] = client

return client
15 changes: 15 additions & 0 deletions pydantic_settings_aws/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,18 @@ class AwsSession(BaseModel):
aws_access_key_id: Optional[str] = None
aws_secret_access_key: Optional[str] = None
aws_session_token: Optional[str] = None

def session_key(self) -> str:
key = ""
for k in self.model_fields.keys():
# session token is too long
if k != "aws_session_token":
v = getattr(self, k)
if v:
key += f"{v}_"
print(key)

if not key:
key = "default"

return key.rstrip("_")
19 changes: 11 additions & 8 deletions tests/aws_mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,9 @@

TARGET_SESSION = "pydantic_settings_aws.aws.boto3.Session"

TARGET_SECRETS_BOTO3_CLIENT = "pydantic_settings_aws.aws._get_secrets_boto3_client"

TARGET_SSM_BOTO3_CLIENT = "pydantic_settings_aws.aws._get_ssm_boto3_client"

TARGET_SECRETS_CLIENT = "pydantic_settings_aws.aws._create_boto3_client"

TARGET_CREATE_CLIENT_FROM_SETTINGS = "pydantic_settings_aws.aws._create_client_from_settings"
TARGET_CREATE_CLIENT_FROM_SETTINGS = (
"pydantic_settings_aws.aws._create_client_from_settings"
)

TARGET_SECRET_CONTENT = "pydantic_settings_aws.aws._get_secrets_content"

Expand All @@ -23,7 +19,14 @@ def mock_secrets_content_empty(*args):
return ClientMock(secret_string=None)


def mock_ssm(*args):
def mock_ssm(
region_name=None,
profile_name=None,
aws_access_key_id=None,
aws_secret_access_key=None,
aws_session_token=None,
*args
):
return ClientMock(ssm_value="value")


Expand Down
30 changes: 28 additions & 2 deletions tests/aws_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
from .aws_mocks import (
TARGET_CREATE_CLIENT_FROM_SETTINGS,
TARGET_SECRET_CONTENT,
TARGET_SECRETS_BOTO3_CLIENT,
TARGET_SECRETS_CLIENT,
TARGET_SESSION,
BaseSettingsMock,
mock_create_client,
Expand All @@ -23,6 +21,7 @@

@mock.patch(TARGET_CREATE_CLIENT_FROM_SETTINGS, mock_ssm)
def test_get_ssm_content_must_return_parameter_content_if_annotated_with_parameter_name(*args):
aws._client_cache = {}
settings = BaseSettingsMock()
settings.model_config = {"aws_region": "region", "aws_profile": "profile"}
parameter_value = aws.get_ssm_content(settings, "field", "my/parameter/name")
Expand All @@ -33,6 +32,7 @@ def test_get_ssm_content_must_return_parameter_content_if_annotated_with_paramet

@mock.patch(TARGET_CREATE_CLIENT_FROM_SETTINGS, mock_ssm)
def test_get_ssm_content_must_return_parameter_content_if_annotated_with_dict_args(*args):
aws._client_cache = {}
settings = BaseSettingsMock()
settings.model_config = {"aws_region": "region", "aws_profile": "profile"}
parameter_value = aws.get_ssm_content(settings, "field", {"ssm": "my/parameter/name"})
Expand All @@ -43,6 +43,7 @@ def test_get_ssm_content_must_return_parameter_content_if_annotated_with_dict_ar

@mock.patch(TARGET_CREATE_CLIENT_FROM_SETTINGS, mock_ssm)
def test_get_ssm_content_must_use_client_if_present_in_metadata(*args):
aws._client_cache = {}
settings = BaseSettingsMock()
settings.model_config = {"aws_region": "region", "aws_profile": "profile"}
parameter_value = aws.get_ssm_content(settings, "field", {"ssm": "my/parameter/name", "ssm_client": mock_ssm()})
Expand All @@ -53,6 +54,7 @@ def test_get_ssm_content_must_use_client_if_present_in_metadata(*args):

@mock.patch(TARGET_CREATE_CLIENT_FROM_SETTINGS, mock_ssm)
def test_get_ssm_content_must_use_field_name_if_ssm_name_not_in_metadata(*args):
aws._client_cache = {}
settings = BaseSettingsMock()
settings.model_config = {"aws_region": "region", "aws_profile": "profile"}
parameter_value = aws.get_ssm_content(settings, "field", None)
Expand All @@ -63,6 +65,7 @@ def test_get_ssm_content_must_use_field_name_if_ssm_name_not_in_metadata(*args):

@mock.patch(TARGET_SESSION, SessionMock)
def test_create_ssm_client(*args):
aws._client_cache = {}
settings = BaseSettingsMock()
settings.model_config = {"aws_region": "region", "aws_profile": "profile"}
client = aws._create_client_from_settings(settings, "ssm", "ssm_client")
Expand All @@ -72,6 +75,7 @@ def test_create_ssm_client(*args):

@mock.patch(TARGET_CREATE_CLIENT_FROM_SETTINGS, mock_create_client)
def test_get_ssm_boto3_client_must_create_a_client_if_its_not_given(*args):
aws._client_cache = {}
settings = BaseSettingsMock()
settings.model_config = {}
client = aws._create_client_from_settings(settings, "ssm", "ssm_client")
Expand All @@ -81,6 +85,7 @@ def test_get_ssm_boto3_client_must_create_a_client_if_its_not_given(*args):

@mock.patch(TARGET_SESSION, SessionMock)
def test_create_secrets_client(*args):
aws._client_cache = {}
settings = BaseSettingsMock()
settings.model_config = {"aws_region": "region", "aws_profile": "profile"}
client = aws._create_client_from_settings(settings, "secretsmanager", "secrets_client")
Expand All @@ -90,6 +95,7 @@ def test_create_secrets_client(*args):

@mock.patch(TARGET_CREATE_CLIENT_FROM_SETTINGS, mock_create_client)
def test_get_secrets_boto3_client_must_create_a_client_if_its_not_given(*args):
aws._client_cache = {}
settings = BaseSettingsMock()
settings.model_config = {}
client = aws._create_client_from_settings(settings, "secretsmanager", "secrets_client")
Expand All @@ -102,6 +108,7 @@ def test_get_secrets_boto3_client_must_create_a_client_if_its_not_given(*args):
def test_get_secrets_content_must_raise_value_error_if_secrets_content_is_none(
*args,
):
aws._client_cache = {}
settings = BaseSettingsMock()
settings.model_config = {
"secrets_name": "secrets/name",
Expand All @@ -115,6 +122,7 @@ def test_get_secrets_content_must_raise_value_error_if_secrets_content_is_none(

@mock.patch(TARGET_CREATE_CLIENT_FROM_SETTINGS, mock_secrets_content_invalid_json)
def test_should_not_obfuscate_json_error_in_case_of_invalid_secrets(*args):
aws._client_cache = {}
settings = BaseSettingsMock()
settings.model_config = {
"secrets_name": "secrets/name",
Expand All @@ -127,6 +135,7 @@ def test_should_not_obfuscate_json_error_in_case_of_invalid_secrets(*args):


def test_get_secrets_content_must_get_binary_content_if_string_is_not_set(*args):
aws._client_cache = {}
content = {
"SecretBinary": json.dumps({"username": "admin"}).encode("utf-8")
}
Expand All @@ -136,6 +145,7 @@ def test_get_secrets_content_must_get_binary_content_if_string_is_not_set(*args)


def test_get_secrets_content_must_not_hide_decode_error_if_not_binary_in_secret_binary(*args):
aws._client_cache = {}
content = {
"SecretBinary": json.dumps({"username": "admin"})
}
Expand All @@ -145,12 +155,14 @@ def test_get_secrets_content_must_not_hide_decode_error_if_not_binary_in_secret_


def test_get_secrets_content_must_return_none_if_neither_string_nor_binary_are_present(*args):
aws._client_cache = {}
secret_content = aws._get_secrets_content({})

assert secret_content is None


def test_get_secrets_content_must_return_none_if_binary_is_present_but_none(*args):
aws._client_cache = {}
content = {
"SecretBinary": None
}
Expand All @@ -160,8 +172,22 @@ def test_get_secrets_content_must_return_none_if_binary_is_present_but_none(*arg


def test_get_secrets_args_must_not_shadow_pydantic_validation_if_required_args_are_not_present(*args):
aws._client_cache = {}
settings = BaseSettingsMock()
settings.model_config = {}

with pytest.raises(ValidationError):
aws._get_secrets_args(settings)


@mock.patch(TARGET_SESSION, mock_ssm)
def test_must_cache_boto3_clients_for_the_same_service_region_and_account(*args):
aws._client_cache = {}

settings = BaseSettingsMock()
settings.model_config = {"aws_region": "region", "aws_profile": "profile"}
aws._create_client_from_settings(settings, "secretsmanager", "secrets_client")
aws._create_client_from_settings(settings, "secretsmanager", "secrets_client")
aws._create_client_from_settings(settings, "ssm", "ssm_client")

assert len(aws._client_cache) == 2
3 changes: 3 additions & 0 deletions tests/boto3_mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ def __init__(
self.secret_bytes = secret_bytes
self.ssm_value = ssm_value

def client(self, *args):
return self

def get_parameter(self, Name=None, WithDecryption=None):
return {
"Parameter": {
Expand Down
6 changes: 6 additions & 0 deletions tests/models_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from pydantic_settings_aws.models import AwsSession


def test_aws_session_key_must_be_default_if_all_values_are_none():
session = AwsSession()
assert session.session_key() == "default"
1 change: 1 addition & 0 deletions tests/settings_mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class ParameterWithTwoSSMClientSettings(ParameterStoreBaseSettings):
)

my_ssm: Annotated[str, {"ssm": "my/parameter", "ssm_client": ClientMock(ssm_value="value")}]
my_ssm_1: Annotated[str, {"ssm": "my/parameter", "ssm_client": ClientMock(ssm_value="value1")}]
my_ssm_2: Annotated[str, "my/ssm/2/parameter"]


Expand Down

0 comments on commit d89b088

Please sign in to comment.