Skip to content

Commit

Permalink
fix: tests
Browse files Browse the repository at this point in the history
  • Loading branch information
MrCloudSec committed Jan 21, 2025
1 parent 8d01eca commit caa7de0
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 147 deletions.
66 changes: 26 additions & 40 deletions prowler/providers/aws/services/opensearch/opensearch_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@ def __init__(self, provider):
super().__init__("opensearch", provider)
self.opensearch_domains = {}
self.__threading_call__(self._list_domain_names)
self.__threading_call__(
self._describe_domain_config, self.opensearch_domains.values()
)
self.__threading_call__(self._describe_domain, self.opensearch_domains.values())
self.__threading_call__(self._list_tags, self.opensearch_domains.values())

Expand All @@ -38,43 +35,6 @@ def _list_domain_names(self, regional_client):
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)

def _describe_domain_config(self, domain):
logger.info("OpenSearch - describing domain configurations...")
try:
regional_client = self.regional_clients[domain.region]
describe_domain = regional_client.describe_domain_config(
DomainName=domain.name
)
for logging_key in [
"SEARCH_SLOW_LOGS",
"INDEX_SLOW_LOGS",
"AUDIT_LOGS",
]:
if logging_key in describe_domain["DomainConfig"].get(
"LogPublishingOptions", {}
).get("Options", {}):
domain.logging.append(
PublishingLoggingOption(
name=logging_key,
enabled=describe_domain["DomainConfig"][
"LogPublishingOptions"
]["Options"][logging_key]["Enabled"],
)
)
try:
domain.access_policy = loads(
describe_domain["DomainConfig"]["AccessPolicies"]["Options"]
)
except JSONDecodeError as error:
logger.warning(
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)

except Exception as error:
logger.error(
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)

def _describe_domain(self, domain):
logger.info("OpenSearch - describing domain configurations...")
try:
Expand Down Expand Up @@ -129,6 +89,32 @@ def _describe_domain(self, domain):
domain.dedicated_master_count = cluster_config.get(
"DedicatedMasterCount", 0
)
for logging_key in [
"SEARCH_SLOW_LOGS",
"INDEX_SLOW_LOGS",
"AUDIT_LOGS",
]:
if logging_key in describe_domain["DomainStatus"].get(
"LogPublishingOptions", {}
):
domain.logging.append(
PublishingLoggingOption(
name=logging_key,
enabled=describe_domain["DomainStatus"][
"LogPublishingOptions"
][logging_key]["Enabled"],
)
)
try:
if describe_domain["DomainStatus"].get("AccessPolicies"):
domain.access_policy = loads(
describe_domain["DomainStatus"]["AccessPolicies"]
)
except JSONDecodeError as error:
logger.warning(
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)

except Exception as error:
logger.error(
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,15 @@ def test_no_domains(self):
OpenSearchService,
)

with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
), mock.patch(
"prowler.providers.aws.services.opensearch.opensearch_service_domains_not_publicly_accessible.opensearch_service_domains_not_publicly_accessible.opensearch_client",
new=OpenSearchService(aws_provider),
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
),
mock.patch(
"prowler.providers.aws.services.opensearch.opensearch_service_domains_not_publicly_accessible.opensearch_service_domains_not_publicly_accessible.opensearch_client",
new=OpenSearchService(aws_provider),
),
):
from prowler.providers.aws.services.opensearch.opensearch_service_domains_not_publicly_accessible.opensearch_service_domains_not_publicly_accessible import (
opensearch_service_domains_not_publicly_accessible,
Expand All @@ -102,26 +105,25 @@ def test_no_domains(self):
@mock_aws
def test_policy_data_restricted(self):
opensearch_client = client("opensearch", region_name=AWS_REGION_US_WEST_2)
domain_arn = opensearch_client.create_domain(DomainName=domain_name)[
"DomainStatus"
]["ARN"]
opensearch_client.update_domain_config(
DomainName=domain_name,
AccessPolicies=str(policy_data_restricted),
)
domain_arn = opensearch_client.create_domain(
DomainName=domain_name, AccessPolicies=str(policy_data_restricted)
)["DomainStatus"]["ARN"]

aws_provider = set_mocked_aws_provider([AWS_REGION_US_WEST_2])

from prowler.providers.aws.services.opensearch.opensearch_service import (
OpenSearchService,
)

with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
), mock.patch(
"prowler.providers.aws.services.opensearch.opensearch_service_domains_not_publicly_accessible.opensearch_service_domains_not_publicly_accessible.opensearch_client",
new=OpenSearchService(aws_provider),
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
),
mock.patch(
"prowler.providers.aws.services.opensearch.opensearch_service_domains_not_publicly_accessible.opensearch_service_domains_not_publicly_accessible.opensearch_client",
new=OpenSearchService(aws_provider),
),
):
from prowler.providers.aws.services.opensearch.opensearch_service_domains_not_publicly_accessible.opensearch_service_domains_not_publicly_accessible import (
opensearch_service_domains_not_publicly_accessible,
Expand All @@ -143,26 +145,25 @@ def test_policy_data_restricted(self):
@mock_aws
def test_policy_data_not_restricted_with_principal_AWS(self):
opensearch_client = client("opensearch", region_name=AWS_REGION_US_WEST_2)
domain_arn = opensearch_client.create_domain(DomainName=domain_name)[
"DomainStatus"
]["ARN"]
opensearch_client.update_domain_config(
DomainName=domain_name,
AccessPolicies=dumps(policy_data_not_restricted),
)
domain_arn = opensearch_client.create_domain(
DomainName=domain_name, AccessPolicies=dumps(policy_data_not_restricted)
)["DomainStatus"]["ARN"]

aws_provider = set_mocked_aws_provider([AWS_REGION_US_WEST_2])

from prowler.providers.aws.services.opensearch.opensearch_service import (
OpenSearchService,
)

with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
), mock.patch(
"prowler.providers.aws.services.opensearch.opensearch_service_domains_not_publicly_accessible.opensearch_service_domains_not_publicly_accessible.opensearch_client",
new=OpenSearchService(aws_provider),
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
),
mock.patch(
"prowler.providers.aws.services.opensearch.opensearch_service_domains_not_publicly_accessible.opensearch_service_domains_not_publicly_accessible.opensearch_client",
new=OpenSearchService(aws_provider),
),
):
from prowler.providers.aws.services.opensearch.opensearch_service_domains_not_publicly_accessible.opensearch_service_domains_not_publicly_accessible import (
opensearch_service_domains_not_publicly_accessible,
Expand All @@ -184,26 +185,26 @@ def test_policy_data_not_restricted_with_principal_AWS(self):
@mock_aws
def test_policy_data_not_restricted_with_principal_no_AWS(self):
opensearch_client = client("opensearch", region_name=AWS_REGION_US_WEST_2)
domain_arn = opensearch_client.create_domain(DomainName=domain_name)[
"DomainStatus"
]["ARN"]
opensearch_client.update_domain_config(
domain_arn = opensearch_client.create_domain(
DomainName=domain_name,
AccessPolicies=dumps(policy_data_not_restricted_principal),
)
)["DomainStatus"]["ARN"]

aws_provider = set_mocked_aws_provider([AWS_REGION_US_WEST_2])

from prowler.providers.aws.services.opensearch.opensearch_service import (
OpenSearchService,
)

with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
), mock.patch(
"prowler.providers.aws.services.opensearch.opensearch_service_domains_not_publicly_accessible.opensearch_service_domains_not_publicly_accessible.opensearch_client",
new=OpenSearchService(aws_provider),
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
),
mock.patch(
"prowler.providers.aws.services.opensearch.opensearch_service_domains_not_publicly_accessible.opensearch_service_domains_not_publicly_accessible.opensearch_client",
new=OpenSearchService(aws_provider),
),
):
from prowler.providers.aws.services.opensearch.opensearch_service_domains_not_publicly_accessible.opensearch_service_domains_not_publicly_accessible import (
opensearch_service_domains_not_publicly_accessible,
Expand All @@ -225,26 +226,26 @@ def test_policy_data_not_restricted_with_principal_no_AWS(self):
@mock_aws
def test_policy_data_not_restricted_ip_full(self):
opensearch_client = client("opensearch", region_name=AWS_REGION_US_WEST_2)
domain_arn = opensearch_client.create_domain(DomainName=domain_name)[
"DomainStatus"
]["ARN"]
opensearch_client.update_domain_config(
domain_arn = opensearch_client.create_domain(
DomainName=domain_name,
AccessPolicies=dumps(policy_data_source_ip_full),
)
)["DomainStatus"]["ARN"]

aws_provider = set_mocked_aws_provider([AWS_REGION_US_WEST_2])

from prowler.providers.aws.services.opensearch.opensearch_service import (
OpenSearchService,
)

with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
), mock.patch(
"prowler.providers.aws.services.opensearch.opensearch_service_domains_not_publicly_accessible.opensearch_service_domains_not_publicly_accessible.opensearch_client",
new=OpenSearchService(aws_provider),
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
),
mock.patch(
"prowler.providers.aws.services.opensearch.opensearch_service_domains_not_publicly_accessible.opensearch_service_domains_not_publicly_accessible.opensearch_client",
new=OpenSearchService(aws_provider),
),
):
from prowler.providers.aws.services.opensearch.opensearch_service_domains_not_publicly_accessible.opensearch_service_domains_not_publicly_accessible import (
opensearch_service_domains_not_publicly_accessible,
Expand All @@ -266,26 +267,26 @@ def test_policy_data_not_restricted_ip_full(self):
@mock_aws
def test_policy_data_not_restricted_whole_internet(self):
opensearch_client = client("opensearch", region_name=AWS_REGION_US_WEST_2)
domain_arn = opensearch_client.create_domain(DomainName=domain_name)[
"DomainStatus"
]["ARN"]
opensearch_client.update_domain_config(
domain_arn = opensearch_client.create_domain(
DomainName=domain_name,
AccessPolicies=dumps(policy_data_source_whole_internet),
)
)["DomainStatus"]["ARN"]

aws_provider = set_mocked_aws_provider([AWS_REGION_US_WEST_2])

from prowler.providers.aws.services.opensearch.opensearch_service import (
OpenSearchService,
)

with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
), mock.patch(
"prowler.providers.aws.services.opensearch.opensearch_service_domains_not_publicly_accessible.opensearch_service_domains_not_publicly_accessible.opensearch_client",
new=OpenSearchService(aws_provider),
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
),
mock.patch(
"prowler.providers.aws.services.opensearch.opensearch_service_domains_not_publicly_accessible.opensearch_service_domains_not_publicly_accessible.opensearch_client",
new=OpenSearchService(aws_provider),
),
):
from prowler.providers.aws.services.opensearch.opensearch_service_domains_not_publicly_accessible.opensearch_service_domains_not_publicly_accessible import (
opensearch_service_domains_not_publicly_accessible,
Expand Down
61 changes: 19 additions & 42 deletions tests/providers/aws/services/opensearch/opensearch_service_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,21 +42,6 @@ def mock_make_api_call(self, operation_name, kwarg):
},
]
}
if operation_name == "DescribeDomainConfig":
return {
"DomainConfig": {
"AccessPolicies": {
"Options": policy_json,
},
"LogPublishingOptions": {
"Options": {
"SEARCH_SLOW_LOGS": {"Enabled": True},
"INDEX_SLOW_LOGS": {"Enabled": True},
"AUDIT_LOGS": {"Enabled": True},
},
},
}
}
if operation_name == "DescribeDomain":
return {
"DomainStatus": {
Expand All @@ -79,19 +64,19 @@ def mock_make_api_call(self, operation_name, kwarg):
"EncryptionAtRestOptions": {"Enabled": True},
"NodeToNodeEncryptionOptions": {"Enabled": True},
"AdvancedOptions": {"string": "string"},
"LogPublishingOptions": {
"string": {
"CloudWatchLogsLogGroupArn": "string",
"Enabled": True | False,
}
},
"ServiceSoftwareOptions": {"UpdateAvailable": True},
"DomainEndpointOptions": {"EnforceHTTPS": True},
"AdvancedSecurityOptions": {
"Enabled": True,
"InternalUserDatabaseEnabled": True,
"SAMLOptions": {"Enabled": True},
},
"AccessPolicies": policy_json,
"LogPublishingOptions": {
"SEARCH_SLOW_LOGS": {"Enabled": True},
"INDEX_SLOW_LOGS": {"Enabled": True},
"AUDIT_LOGS": {"Enabled": True},
},
}
}
if operation_name == "ListTags":
Expand Down Expand Up @@ -144,27 +129,6 @@ def test_list_domain_names(self):
assert opensearch.opensearch_domains[domain_arn].name == test_domain_name
assert opensearch.opensearch_domains[domain_arn].region == AWS_REGION_EU_WEST_1

# Test OpenSearchService describe domain config
def test_describe_domain_config(self):
aws_provider = set_mocked_aws_provider([])
opensearch = OpenSearchService(aws_provider)
assert len(opensearch.opensearch_domains) == 1
assert opensearch.opensearch_domains[domain_arn].name == test_domain_name
assert opensearch.opensearch_domains[domain_arn].region == AWS_REGION_EU_WEST_1
assert opensearch.opensearch_domains[domain_arn].access_policy
assert (
opensearch.opensearch_domains[domain_arn].logging[0].name
== "SEARCH_SLOW_LOGS"
)
assert opensearch.opensearch_domains[domain_arn].logging[0].enabled
assert (
opensearch.opensearch_domains[domain_arn].logging[1].name
== "INDEX_SLOW_LOGS"
)
assert opensearch.opensearch_domains[domain_arn].logging[1].enabled
assert opensearch.opensearch_domains[domain_arn].logging[2].name == "AUDIT_LOGS"
assert opensearch.opensearch_domains[domain_arn].logging[2].enabled

# Test OpenSearchService describe domain
@mock_aws
def test_describe_domain(self):
Expand Down Expand Up @@ -193,6 +157,19 @@ def test_describe_domain(self):
assert opensearch.opensearch_domains[domain_arn].zone_awareness_enabled
assert opensearch.opensearch_domains[domain_arn].dedicated_master_enabled
assert opensearch.opensearch_domains[domain_arn].dedicated_master_count == 1
assert opensearch.opensearch_domains[domain_arn].access_policy
assert (
opensearch.opensearch_domains[domain_arn].logging[0].name
== "SEARCH_SLOW_LOGS"
)
assert opensearch.opensearch_domains[domain_arn].logging[0].enabled
assert (
opensearch.opensearch_domains[domain_arn].logging[1].name
== "INDEX_SLOW_LOGS"
)
assert opensearch.opensearch_domains[domain_arn].logging[1].enabled
assert opensearch.opensearch_domains[domain_arn].logging[2].name == "AUDIT_LOGS"
assert opensearch.opensearch_domains[domain_arn].logging[2].enabled
assert opensearch.opensearch_domains[domain_arn].tags == [
{"Key": "test", "Value": "test"},
]

0 comments on commit caa7de0

Please sign in to comment.