From 86c64dfa0f04a1f9f3121f896f73c3fcb8832b05 Mon Sep 17 00:00:00 2001 From: leohoare Date: Wed, 22 Jan 2025 20:56:16 +1100 Subject: [PATCH] add test coverage, async providers calling sync calls, async only client Signed-off-by: leohoare --- tests/provider/test_in_memory_provider.py | 4 +- tests/provider/test_provider_compatibility.py | 196 ++++++++++++++++++ tests/test_client.py | 140 ++++++++++--- 3 files changed, 309 insertions(+), 31 deletions(-) create mode 100644 tests/provider/test_provider_compatibility.py diff --git a/tests/provider/test_in_memory_provider.py b/tests/provider/test_in_memory_provider.py index f3559363..cdcea7bf 100644 --- a/tests/provider/test_in_memory_provider.py +++ b/tests/provider/test_in_memory_provider.py @@ -154,7 +154,9 @@ async def test_should_resolve_list_flag_from_in_memory(): ) # When flag_sync = provider.resolve_object_details(flag_key="Key", default_value=[]) - flag_async = provider.resolve_object_details(flag_key="Key", default_value=[]) + flag_async = await provider.resolve_object_details_async( + flag_key="Key", default_value=[] + ) # Then assert flag_sync == flag_async for flag in [flag_sync, flag_async]: diff --git a/tests/provider/test_provider_compatibility.py b/tests/provider/test_provider_compatibility.py new file mode 100644 index 00000000..d90859c4 --- /dev/null +++ b/tests/provider/test_provider_compatibility.py @@ -0,0 +1,196 @@ +import asyncio +from typing import Optional, Union + +import pytest + +from openfeature.api import OpenFeatureClient, get_client, set_provider +from openfeature.evaluation_context import EvaluationContext +from openfeature.flag_evaluation import FlagResolutionDetails +from openfeature.provider import AbstractProvider, Metadata + + +class SynchronousProvider(AbstractProvider): + def get_metadata(self): + return Metadata(name="SynchronousProvider") + + def get_provider_hooks(self): + return [] + + def resolve_boolean_details( + self, + flag_key: str, + default_value: bool, + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[bool]: + return FlagResolutionDetails(value=True) + + def resolve_string_details( + self, + flag_key: str, + default_value: str, + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[str]: + return FlagResolutionDetails(value="string") + + def resolve_integer_details( + self, + flag_key: str, + default_value: int, + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[int]: + return FlagResolutionDetails(value=1) + + def resolve_float_details( + self, + flag_key: str, + default_value: float, + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[float]: + return FlagResolutionDetails(value=10.0) + + def resolve_object_details( + self, + flag_key: str, + default_value: Union[dict, list], + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[Union[dict, list]]: + return FlagResolutionDetails(value={"key": "value"}) + + +@pytest.mark.parametrize( + "flag_type, default_value, get_method", + ( + (bool, True, "get_boolean_value_async"), + (str, "string", "get_string_value_async"), + (int, 1, "get_integer_value_async"), + (float, 10.0, "get_float_value_async"), + ( + dict, + {"key": "value"}, + "get_object_value_async", + ), + ), +) +@pytest.mark.asyncio +async def test_sync_provider_can_be_called_async(flag_type, default_value, get_method): + # Given + set_provider(SynchronousProvider(), "SynchronousProvider") + client = get_client("SynchronousProvider") + # When + async_callable = getattr(client, get_method) + flag = await async_callable(flag_key="Key", default_value=default_value) + # Then + assert flag is not None + assert flag == default_value + assert isinstance(flag, flag_type) + + +@pytest.mark.asyncio +async def test_sync_provider_can_be_extended_async(): + # Given + class ExtendedAsyncProvider(SynchronousProvider): + async def resolve_boolean_details_async( + self, + flag_key: str, + default_value: bool, + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[bool]: + return FlagResolutionDetails(value=False) + + set_provider(ExtendedAsyncProvider(), "ExtendedAsyncProvider") + client = get_client("ExtendedAsyncProvider") + # When + flag = await client.get_boolean_value_async(flag_key="Key", default_value=True) + # Then + assert flag is not None + assert flag is False + + +# We're not allowing providers to only have async methods +def test_sync_methods_enforced_for_async_providers(): + # Given + class AsyncProvider(AbstractProvider): + def get_metadata(self): + return Metadata(name="AsyncProvider") + + async def resolve_boolean_details_async( + self, + flag_key: str, + default_value: bool, + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[bool]: + return FlagResolutionDetails(value=True) + + # When + with pytest.raises(TypeError) as exception: + set_provider(AsyncProvider(), "AsyncProvider") + + # Then + # assert + assert str(exception.value).startswith( + "Can't instantiate abstract class AsyncProvider with abstract methods resolve_boolean_details" + ) + + +@pytest.mark.asyncio +async def test_async_provider_not_implemented_exception_workaround(): + # Given + class SyncNotImplementedProvider(AbstractProvider): + def get_metadata(self): + return Metadata(name="AsyncProvider") + + async def resolve_boolean_details_async( + self, + flag_key: str, + default_value: bool, + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[bool]: + return FlagResolutionDetails(value=True) + + def resolve_boolean_details( + self, + flag_key: str, + default_value: bool, + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[bool]: + raise NotImplementedError("Use the async method") + + def resolve_string_details( + self, + flag_key: str, + default_value: str, + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[str]: + raise NotImplementedError("Use the async method") + + def resolve_integer_details( + self, + flag_key: str, + default_value: int, + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[int]: + raise NotImplementedError("Use the async method") + + def resolve_float_details( + self, + flag_key: str, + default_value: float, + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[float]: + raise NotImplementedError("Use the async method") + + def resolve_object_details( + self, + flag_key: str, + default_value: Union[dict, list], + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[Union[dict, list]]: + raise NotImplementedError("Use the async method") + + # When + set_provider(SyncNotImplementedProvider(), "SyncNotImplementedProvider") + client = get_client("SyncNotImplementedProvider") + flag = await client.get_boolean_value_async(flag_key="Key", default_value=False) + # Then + assert flag is not None + assert flag is True diff --git a/tests/test_client.py b/tests/test_client.py index 4c00a95a..695a92c1 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -8,7 +8,7 @@ from openfeature import api from openfeature.api import add_hooks, clear_hooks, get_client, set_provider -from openfeature.client import OpenFeatureClient +from openfeature.client import GeneralError, OpenFeatureClient, _typecheck_flag_value from openfeature.evaluation_context import EvaluationContext from openfeature.event import EventDetails, ProviderEvent, ProviderEventDetails from openfeature.exception import ErrorCode, OpenFeatureError @@ -24,9 +24,13 @@ "flag_type, default_value, get_method", ( (bool, True, "get_boolean_value"), + (bool, True, "get_boolean_value_async"), (str, "String", "get_string_value"), + (str, "String", "get_string_value_async"), (int, 100, "get_integer_value"), + (int, 100, "get_integer_value_async"), (float, 10.23, "get_float_value"), + (float, 10.23, "get_float_value_async"), ( dict, { @@ -36,21 +40,38 @@ }, "get_object_value", ), + ( + dict, + { + "String": "string", + "Number": 2, + "Boolean": True, + }, + "get_object_value_async", + ), ( list, ["string1", "string2"], "get_object_value", ), + ( + list, + ["string1", "string2"], + "get_object_value_async", + ), ), ) -def test_should_get_flag_value_based_on_method_type( +@pytest.mark.asyncio +async def test_should_get_flag_value_based_on_method_type( flag_type, default_value, get_method, no_op_provider_client ): # Given # When - flag = getattr(no_op_provider_client, get_method)( - flag_key="Key", default_value=default_value - ) + method = getattr(no_op_provider_client, get_method) + if asyncio.iscoroutinefunction(method): + flag = await method(flag_key="Key", default_value=default_value) + else: + flag = method(flag_key="Key", default_value=default_value) # Then assert flag is not None assert flag == default_value @@ -115,19 +136,24 @@ async def test_should_get_flag_detail_based_on_method_type( assert isinstance(flag.value, flag_type) -def test_should_raise_exception_when_invalid_flag_type_provided( +@pytest.mark.asyncio +async def test_should_raise_exception_when_invalid_flag_type_provided( no_op_provider_client, ): # Given # When - flag = no_op_provider_client.evaluate_flag_details( + flag_sync = no_op_provider_client.evaluate_flag_details( + flag_type=None, flag_key="Key", default_value=True + ) + flag_async = await no_op_provider_client.evaluate_flag_details_async( flag_type=None, flag_key="Key", default_value=True ) # Then - assert flag.value - assert flag.error_message == "Unknown flag type" - assert flag.error_code == ErrorCode.GENERAL - assert flag.reason == Reason.ERROR + for flag in [flag_sync, flag_async]: + assert flag.value + assert flag.error_message == "Unknown flag type" + assert flag.error_code == ErrorCode.GENERAL + assert flag.reason == Reason.ERROR def test_should_pass_flag_metadata_from_resolution_to_evaluation_details(): @@ -226,7 +252,8 @@ def test_should_define_a_provider_status_accessor(no_op_provider_client): # Requirement 1.7.6 -def test_should_shortcircuit_if_provider_is_not_ready( +@pytest.mark.asyncio +async def test_should_shortcircuit_if_provider_is_not_ready( no_op_provider_client, monkeypatch ): # Given @@ -236,19 +263,26 @@ def test_should_shortcircuit_if_provider_is_not_ready( spy_hook = MagicMock(spec=Hook) no_op_provider_client.add_hooks([spy_hook]) # When - flag_details = no_op_provider_client.get_boolean_details( + flag_details_sync = no_op_provider_client.get_boolean_details( + flag_key="Key", default_value=True + ) + spy_hook.error.assert_called_once() + spy_hook.reset_mock() + flag_details_async = await no_op_provider_client.get_boolean_details_async( flag_key="Key", default_value=True ) # Then - assert flag_details is not None - assert flag_details.value - assert flag_details.reason == Reason.ERROR - assert flag_details.error_code == ErrorCode.PROVIDER_NOT_READY + for flag_details in [flag_details_sync, flag_details_async]: + assert flag_details is not None + assert flag_details.value + assert flag_details.reason == Reason.ERROR + assert flag_details.error_code == ErrorCode.PROVIDER_NOT_READY spy_hook.error.assert_called_once() # Requirement 1.7.7 -def test_should_shortcircuit_if_provider_is_in_irrecoverable_error_state( +@pytest.mark.asyncio +async def test_should_shortcircuit_if_provider_is_in_irrecoverable_error_state( no_op_provider_client, monkeypatch ): # Given @@ -258,41 +292,87 @@ def test_should_shortcircuit_if_provider_is_in_irrecoverable_error_state( spy_hook = MagicMock(spec=Hook) no_op_provider_client.add_hooks([spy_hook]) # When - flag_details = no_op_provider_client.get_boolean_details( + flag_details_sync = no_op_provider_client.get_boolean_details( + flag_key="Key", default_value=True + ) + spy_hook.error.assert_called_once() + spy_hook.reset_mock() + flag_details_async = await no_op_provider_client.get_boolean_details_async( flag_key="Key", default_value=True ) # Then - assert flag_details is not None - assert flag_details.value - assert flag_details.reason == Reason.ERROR - assert flag_details.error_code == ErrorCode.PROVIDER_FATAL + for flag_details in [flag_details_sync, flag_details_async]: + assert flag_details is not None + assert flag_details.value + assert flag_details.reason == Reason.ERROR + assert flag_details.error_code == ErrorCode.PROVIDER_FATAL spy_hook.error.assert_called_once() -def test_should_run_error_hooks_if_provider_returns_resolution_with_error_code(): +@pytest.mark.asyncio +async def test_should_run_error_hooks_if_provider_returns_resolution_with_error_code(): # Given spy_hook = MagicMock(spec=Hook) provider = MagicMock(spec=FeatureProvider) provider.get_provider_hooks.return_value = [] - provider.resolve_boolean_details.return_value = FlagResolutionDetails( + mock_resolution = FlagResolutionDetails( value=True, reason=Reason.ERROR, error_code=ErrorCode.PROVIDER_FATAL, error_message="This is an error message", ) + provider.resolve_boolean_details.return_value = mock_resolution + provider.resolve_boolean_details_async.return_value = mock_resolution set_provider(provider) client = get_client() client.add_hooks([spy_hook]) # When - flag_details = client.get_boolean_details(flag_key="Key", default_value=True) + flag_details_sync = client.get_boolean_details(flag_key="Key", default_value=True) + spy_hook.error.assert_called_once() + spy_hook.reset_mock() + flag_details_async = await client.get_boolean_details_async( + flag_key="Key", default_value=True + ) # Then - assert flag_details is not None - assert flag_details.value - assert flag_details.reason == Reason.ERROR - assert flag_details.error_code == ErrorCode.PROVIDER_FATAL + for flag_details in [flag_details_sync, flag_details_async]: + assert flag_details is not None + assert flag_details.value + assert flag_details.reason == Reason.ERROR + assert flag_details.error_code == ErrorCode.PROVIDER_FATAL spy_hook.error.assert_called_once() +@pytest.mark.asyncio +async def test_client_type_mismatch_exceptions(): + # Given + client = get_client() + # When + flag_details_sync = client.get_boolean_details( + flag_key="Key", default_value="type mismatch" + ) + flag_details_async = await client.get_boolean_details_async( + flag_key="Key", default_value="type mismatch" + ) + # Then + for flag_details in [flag_details_sync, flag_details_async]: + assert flag_details is not None + assert flag_details.value + assert flag_details.reason == Reason.ERROR + assert flag_details.error_code == ErrorCode.TYPE_MISMATCH + + +@pytest.mark.asyncio +async def test_client_general_exception(): + # Given + flag_value = "A" + flag_type = None + # When + with pytest.raises(GeneralError) as e: + flag_type = _typecheck_flag_value(flag_value, flag_type) + # Then + assert e.value.error_message == "Unknown flag type" + + def test_provider_events(): # Given provider = NoOpProvider()