diff --git a/README.md b/README.md index 5cc194d1..8fbb6b8a 100644 --- a/README.md +++ b/README.md @@ -316,6 +316,25 @@ async def some_endpoint(): return create_response() ``` +### Asynchronous Feature Retrieval + +The OpenFeature API supports asynchronous calls, enabling non-blocking feature evaluations for improved performance, especially useful in concurrent or latency-sensitive scenarios. If a provider *hasn't* implemented asynchronous calls, the client can still be used asynchronously, but calls will be blocking (synchronous). + +```python +import asyncio +from openfeature import api +from openfeature.provider.in_memory_provider import InMemoryFlag, InMemoryProvider + +my_flags = { "v2_enabled": InMemoryFlag("on", {"on": True, "off": False}) } +api.set_provider(InMemoryProvider(my_flags)) +client = api.get_client() +flag_value = await client.get_boolean_value_async("v2_enabled", False) # API calls are suffixed by _async + +print("Value: " + str(flag_value)) +``` + +See the [develop a provider](#develop-a-provider) for how to support asynchronous functionality in providers. + ### Shutdown The OpenFeature API provides a shutdown function to perform a cleanup of all registered providers. This should only be called when your application is in the process of shutting down. @@ -390,6 +409,56 @@ class MyProvider(AbstractProvider): ... ``` +Providers can also be extended to support async functionality. +To support add asynchronous calls to a provider: +* Implement the `AbstractProvider` as shown above. +* Define asynchronous calls for each data type. + +```python +class MyProvider(AbstractProvider): + ... + async def resolve_boolean_details_async( + self, + flag_key: str, + default_value: bool, + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[bool]: + ... + + async def resolve_string_details_async( + self, + flag_key: str, + default_value: str, + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[str]: + ... + + async def resolve_integer_details_async( + self, + flag_key: str, + default_value: int, + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[int]: + ... + + async def resolve_float_details_async( + self, + flag_key: str, + default_value: float, + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[float]: + ... + + async def resolve_object_details_async( + self, + flag_key: str, + default_value: Union[dict, list], + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[Union[dict, list]]: + ... + +``` + > Built a new provider? [Let us know](https://github.com/open-feature/openfeature.dev/issues/new?assignees=&labels=provider&projects=&template=document-provider.yaml&title=%5BProvider%5D%3A+) so we can add it to the docs! ### Develop a hook diff --git a/openfeature/client.py b/openfeature/client.py index cd82694b..326359eb 100644 --- a/openfeature/client.py +++ b/openfeature/client.py @@ -20,7 +20,7 @@ FlagType, Reason, ) -from openfeature.hook import Hook, HookContext +from openfeature.hook import Hook, HookContext, HookHints from openfeature.hook._hook_support import ( after_all_hooks, after_hooks, @@ -55,6 +55,28 @@ FlagResolutionDetails[typing.Union[dict, list]], ], ] +GetDetailCallableAsync = typing.Union[ + typing.Callable[ + [str, bool, typing.Optional[EvaluationContext]], + typing.Awaitable[FlagResolutionDetails[bool]], + ], + typing.Callable[ + [str, int, typing.Optional[EvaluationContext]], + typing.Awaitable[FlagResolutionDetails[int]], + ], + typing.Callable[ + [str, float, typing.Optional[EvaluationContext]], + typing.Awaitable[FlagResolutionDetails[float]], + ], + typing.Callable[ + [str, str, typing.Optional[EvaluationContext]], + typing.Awaitable[FlagResolutionDetails[str]], + ], + typing.Callable[ + [str, typing.Union[dict, list], typing.Optional[EvaluationContext]], + typing.Awaitable[FlagResolutionDetails[typing.Union[dict, list]]], + ], +] TypeMap = typing.Dict[ FlagType, typing.Union[ @@ -113,6 +135,21 @@ def get_boolean_value( flag_evaluation_options, ).value + async def get_boolean_value_async( + self, + flag_key: str, + default_value: bool, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> bool: + details = await self.get_boolean_details_async( + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + return details.value + def get_boolean_details( self, flag_key: str, @@ -128,6 +165,21 @@ def get_boolean_details( flag_evaluation_options, ) + async def get_boolean_details_async( + self, + flag_key: str, + default_value: bool, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> FlagEvaluationDetails[bool]: + return await self.evaluate_flag_details_async( + FlagType.BOOLEAN, + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + def get_string_value( self, flag_key: str, @@ -142,6 +194,21 @@ def get_string_value( flag_evaluation_options, ).value + async def get_string_value_async( + self, + flag_key: str, + default_value: str, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> str: + details = await self.get_string_details_async( + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + return details.value + def get_string_details( self, flag_key: str, @@ -157,6 +224,21 @@ def get_string_details( flag_evaluation_options, ) + async def get_string_details_async( + self, + flag_key: str, + default_value: str, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> FlagEvaluationDetails[str]: + return await self.evaluate_flag_details_async( + FlagType.STRING, + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + def get_integer_value( self, flag_key: str, @@ -171,6 +253,21 @@ def get_integer_value( flag_evaluation_options, ).value + async def get_integer_value_async( + self, + flag_key: str, + default_value: int, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> int: + details = await self.get_integer_details_async( + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + return details.value + def get_integer_details( self, flag_key: str, @@ -186,6 +283,21 @@ def get_integer_details( flag_evaluation_options, ) + async def get_integer_details_async( + self, + flag_key: str, + default_value: int, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> FlagEvaluationDetails[int]: + return await self.evaluate_flag_details_async( + FlagType.INTEGER, + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + def get_float_value( self, flag_key: str, @@ -200,6 +312,21 @@ def get_float_value( flag_evaluation_options, ).value + async def get_float_value_async( + self, + flag_key: str, + default_value: float, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> float: + details = await self.get_float_details_async( + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + return details.value + def get_float_details( self, flag_key: str, @@ -215,6 +342,21 @@ def get_float_details( flag_evaluation_options, ) + async def get_float_details_async( + self, + flag_key: str, + default_value: float, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> FlagEvaluationDetails[float]: + return await self.evaluate_flag_details_async( + FlagType.FLOAT, + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + def get_object_value( self, flag_key: str, @@ -229,6 +371,21 @@ def get_object_value( flag_evaluation_options, ).value + async def get_object_value_async( + self, + flag_key: str, + default_value: typing.Union[dict, list], + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> typing.Union[dict, list]: + details = await self.get_object_details_async( + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + return details.value + def get_object_details( self, flag_key: str, @@ -244,26 +401,35 @@ def get_object_details( flag_evaluation_options, ) - def evaluate_flag_details( # noqa: PLR0915 + async def get_object_details_async( self, - flag_type: FlagType, flag_key: str, - default_value: typing.Any, + default_value: typing.Union[dict, list], evaluation_context: typing.Optional[EvaluationContext] = None, flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, - ) -> FlagEvaluationDetails[typing.Any]: - """ - Evaluate the flag requested by the user from the clients provider. - - :param flag_type: the type of the flag being returned - :param flag_key: the string key of the selected flag - :param default_value: backup value returned if no result found by the provider - :param evaluation_context: Information for the purposes of flag evaluation - :param flag_evaluation_options: Additional flag evaluation information - :return: a FlagEvaluationDetails object with the fully evaluated flag from a - provider - """ + ) -> FlagEvaluationDetails[typing.Union[dict, list]]: + return await self.evaluate_flag_details_async( + FlagType.OBJECT, + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + def _establish_hooks_and_provider( + self, + flag_type: FlagType, + flag_key: str, + default_value: typing.Any, + evaluation_context: typing.Optional[EvaluationContext], + flag_evaluation_options: typing.Optional[FlagEvaluationOptions], + ) -> typing.Tuple[ + FeatureProvider, + HookContext, + HookHints, + typing.List[Hook], + typing.List[Hook], + ]: if evaluation_context is None: evaluation_context = EvaluationContext() @@ -295,54 +461,179 @@ def evaluate_flag_details( # noqa: PLR0915 reversed_merged_hooks = merged_hooks[:] reversed_merged_hooks.reverse() + return provider, hook_context, hook_hints, merged_hooks, reversed_merged_hooks + + def _assert_provider_status( + self, + ) -> None: + status = self.get_provider_status() + if status == ProviderStatus.NOT_READY: + raise ProviderNotReadyError() + if status == ProviderStatus.FATAL: + raise ProviderFatalError() + return None + + def _before_hooks_and_merge_context( + self, + flag_type: FlagType, + hook_context: HookContext, + merged_hooks: typing.List[Hook], + hook_hints: HookHints, + evaluation_context: typing.Optional[EvaluationContext], + ) -> EvaluationContext: + # https://github.com/open-feature/spec/blob/main/specification/sections/03-evaluation-context.md + # Any resulting evaluation context from a before hook will overwrite + # duplicate fields defined globally, on the client, or in the invocation. + # Requirement 3.2.2, 4.3.4: API.context->client.context->invocation.context + invocation_context = before_hooks( + flag_type, hook_context, merged_hooks, hook_hints + ) + if evaluation_context: + invocation_context = invocation_context.merge(ctx2=evaluation_context) + + # Requirement 3.2.2 merge: API.context->transaction.context->client.context->invocation.context + merged_context = ( + api.get_evaluation_context() + .merge(api.get_transaction_context()) + .merge(self.context) + .merge(invocation_context) + ) + return merged_context + + async def evaluate_flag_details_async( + self, + flag_type: FlagType, + flag_key: str, + default_value: typing.Any, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> FlagEvaluationDetails[typing.Any]: + """ + Evaluate the flag requested by the user from the clients provider. + + :param flag_type: the type of the flag being returned + :param flag_key: the string key of the selected flag + :param default_value: backup value returned if no result found by the provider + :param evaluation_context: Information for the purposes of flag evaluation + :param flag_evaluation_options: Additional flag evaluation information + :return: a typing.Awaitable[FlagEvaluationDetails] object with the fully evaluated flag from a + provider + """ + provider, hook_context, hook_hints, merged_hooks, reversed_merged_hooks = ( + self._establish_hooks_and_provider( + flag_type, + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + ) + try: - status = self.get_provider_status() - if status == ProviderStatus.NOT_READY: - error_hooks( - flag_type, - hook_context, - ProviderNotReadyError(), - reversed_merged_hooks, - hook_hints, - ) - flag_evaluation = FlagEvaluationDetails( - flag_key=flag_key, - value=default_value, - reason=Reason.ERROR, - error_code=ErrorCode.PROVIDER_NOT_READY, - ) - return flag_evaluation - if status == ProviderStatus.FATAL: - error_hooks( - flag_type, - hook_context, - ProviderFatalError(), - reversed_merged_hooks, - hook_hints, - ) - flag_evaluation = FlagEvaluationDetails( - flag_key=flag_key, - value=default_value, - reason=Reason.ERROR, - error_code=ErrorCode.PROVIDER_FATAL, - ) - return flag_evaluation - - # https://github.com/open-feature/spec/blob/main/specification/sections/03-evaluation-context.md - # Any resulting evaluation context from a before hook will overwrite - # duplicate fields defined globally, on the client, or in the invocation. - # Requirement 3.2.2, 4.3.4: API.context->client.context->invocation.context - invocation_context = before_hooks( - flag_type, hook_context, merged_hooks, hook_hints + self._assert_provider_status() + + merged_context = self._before_hooks_and_merge_context( + flag_type, + hook_context, + merged_hooks, + hook_hints, + evaluation_context, + ) + + flag_evaluation = await self._create_provider_evaluation_async( + provider, + flag_type, + flag_key, + default_value, + merged_context, ) - invocation_context = invocation_context.merge(ctx2=evaluation_context) - # Requirement 3.2.2 merge: API.context->transaction.context->client.context->invocation.context - merged_context = ( - api.get_evaluation_context() - .merge(api.get_transaction_context()) - .merge(self.context) - .merge(invocation_context) + after_hooks( + flag_type, + hook_context, + flag_evaluation, + reversed_merged_hooks, + hook_hints, + ) + + return flag_evaluation + + except OpenFeatureError as err: + error_hooks(flag_type, hook_context, err, reversed_merged_hooks, hook_hints) + flag_evaluation = FlagEvaluationDetails( + flag_key=flag_key, + value=default_value, + reason=Reason.ERROR, + error_code=err.error_code, + error_message=err.error_message, + ) + return flag_evaluation + # Catch any type of exception here since the user can provide any exception + # in the error hooks + except Exception as err: # pragma: no cover + logger.exception( + "Unable to correctly evaluate flag with key: '%s'", flag_key + ) + + error_hooks(flag_type, hook_context, err, reversed_merged_hooks, hook_hints) + + error_message = getattr(err, "error_message", str(err)) + flag_evaluation = FlagEvaluationDetails( + flag_key=flag_key, + value=default_value, + reason=Reason.ERROR, + error_code=ErrorCode.GENERAL, + error_message=error_message, + ) + return flag_evaluation + + finally: + after_all_hooks( + flag_type, + hook_context, + flag_evaluation, + reversed_merged_hooks, + hook_hints, + ) + + def evaluate_flag_details( + self, + flag_type: FlagType, + flag_key: str, + default_value: typing.Any, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> FlagEvaluationDetails[typing.Any]: + """ + Evaluate the flag requested by the user from the clients provider. + + :param flag_type: the type of the flag being returned + :param flag_key: the string key of the selected flag + :param default_value: backup value returned if no result found by the provider + :param evaluation_context: Information for the purposes of flag evaluation + :param flag_evaluation_options: Additional flag evaluation information + :return: a FlagEvaluationDetails object with the fully evaluated flag from a + provider + """ + provider, hook_context, hook_hints, merged_hooks, reversed_merged_hooks = ( + self._establish_hooks_and_provider( + flag_type, + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + ) + + try: + self._assert_provider_status() + + merged_context = self._before_hooks_and_merge_context( + flag_type, + hook_context, + merged_hooks, + hook_hints, + evaluation_context, ) flag_evaluation = self._create_provider_evaluation( @@ -402,6 +693,48 @@ def evaluate_flag_details( # noqa: PLR0915 hook_hints, ) + async def _create_provider_evaluation_async( + self, + provider: FeatureProvider, + flag_type: FlagType, + flag_key: str, + default_value: typing.Any, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagEvaluationDetails[typing.Any]: + args = ( + flag_key, + default_value, + evaluation_context, + ) + get_details_callables_async: typing.Mapping[ + FlagType, GetDetailCallableAsync + ] = { + FlagType.BOOLEAN: provider.resolve_boolean_details_async, + FlagType.INTEGER: provider.resolve_integer_details_async, + FlagType.FLOAT: provider.resolve_float_details_async, + FlagType.OBJECT: provider.resolve_object_details_async, + FlagType.STRING: provider.resolve_string_details_async, + } + get_details_callable = get_details_callables_async.get(flag_type) + if not get_details_callable: + raise GeneralError(error_message="Unknown flag type") + + resolution = await get_details_callable(*args) + resolution.raise_for_error() + + # we need to check the get_args to be compatible with union types. + _typecheck_flag_value(resolution.value, flag_type) + + return FlagEvaluationDetails( + flag_key=flag_key, + value=resolution.value, + variant=resolution.variant, + flag_metadata=resolution.flag_metadata or {}, + reason=resolution.reason, + error_code=resolution.error_code, + error_message=resolution.error_message, + ) + def _create_provider_evaluation( self, provider: FeatureProvider, diff --git a/openfeature/provider/__init__.py b/openfeature/provider/__init__.py index b390f928..6a782635 100644 --- a/openfeature/provider/__init__.py +++ b/openfeature/provider/__init__.py @@ -47,6 +47,13 @@ def resolve_boolean_details( evaluation_context: typing.Optional[EvaluationContext] = None, ) -> FlagResolutionDetails[bool]: ... + async def resolve_boolean_details_async( + self, + flag_key: str, + default_value: bool, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[bool]: ... + def resolve_string_details( self, flag_key: str, @@ -54,6 +61,13 @@ def resolve_string_details( evaluation_context: typing.Optional[EvaluationContext] = None, ) -> FlagResolutionDetails[str]: ... + async def resolve_string_details_async( + self, + flag_key: str, + default_value: str, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[str]: ... + def resolve_integer_details( self, flag_key: str, @@ -61,6 +75,13 @@ def resolve_integer_details( evaluation_context: typing.Optional[EvaluationContext] = None, ) -> FlagResolutionDetails[int]: ... + async def resolve_integer_details_async( + self, + flag_key: str, + default_value: int, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[int]: ... + def resolve_float_details( self, flag_key: str, @@ -68,6 +89,13 @@ def resolve_float_details( evaluation_context: typing.Optional[EvaluationContext] = None, ) -> FlagResolutionDetails[float]: ... + async def resolve_float_details_async( + self, + flag_key: str, + default_value: float, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[float]: ... + def resolve_object_details( self, flag_key: str, @@ -75,6 +103,13 @@ def resolve_object_details( evaluation_context: typing.Optional[EvaluationContext] = None, ) -> FlagResolutionDetails[typing.Union[dict, list]]: ... + async def resolve_object_details_async( + self, + flag_key: str, + default_value: typing.Union[dict, list], + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[typing.Union[dict, list]]: ... + class AbstractProvider(FeatureProvider): def attach( @@ -111,6 +146,14 @@ def resolve_boolean_details( ) -> FlagResolutionDetails[bool]: pass + async def resolve_boolean_details_async( + self, + flag_key: str, + default_value: bool, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[bool]: + return self.resolve_boolean_details(flag_key, default_value, evaluation_context) + @abstractmethod def resolve_string_details( self, @@ -120,6 +163,14 @@ def resolve_string_details( ) -> FlagResolutionDetails[str]: pass + async def resolve_string_details_async( + self, + flag_key: str, + default_value: str, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[str]: + return self.resolve_string_details(flag_key, default_value, evaluation_context) + @abstractmethod def resolve_integer_details( self, @@ -129,6 +180,14 @@ def resolve_integer_details( ) -> FlagResolutionDetails[int]: pass + async def resolve_integer_details_async( + self, + flag_key: str, + default_value: int, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[int]: + return self.resolve_integer_details(flag_key, default_value, evaluation_context) + @abstractmethod def resolve_float_details( self, @@ -138,6 +197,14 @@ def resolve_float_details( ) -> FlagResolutionDetails[float]: pass + async def resolve_float_details_async( + self, + flag_key: str, + default_value: float, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[float]: + return self.resolve_float_details(flag_key, default_value, evaluation_context) + @abstractmethod def resolve_object_details( self, @@ -147,6 +214,14 @@ def resolve_object_details( ) -> FlagResolutionDetails[typing.Union[dict, list]]: pass + async def resolve_object_details_async( + self, + flag_key: str, + default_value: typing.Union[dict, list], + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[typing.Union[dict, list]]: + return self.resolve_object_details(flag_key, default_value, evaluation_context) + def emit_provider_ready(self, details: ProviderEventDetails) -> None: self.emit(ProviderEvent.PROVIDER_READY, details) diff --git a/openfeature/provider/in_memory_provider.py b/openfeature/provider/in_memory_provider.py index 322f4ed6..d64a7735 100644 --- a/openfeature/provider/in_memory_provider.py +++ b/openfeature/provider/in_memory_provider.py @@ -76,6 +76,14 @@ def resolve_boolean_details( ) -> FlagResolutionDetails[bool]: return self._resolve(flag_key, evaluation_context) + async def resolve_boolean_details_async( + self, + flag_key: str, + default_value: bool, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[bool]: + return await self._resolve_async(flag_key, evaluation_context) + def resolve_string_details( self, flag_key: str, @@ -84,6 +92,14 @@ def resolve_string_details( ) -> FlagResolutionDetails[str]: return self._resolve(flag_key, evaluation_context) + async def resolve_string_details_async( + self, + flag_key: str, + default_value: str, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[str]: + return await self._resolve_async(flag_key, evaluation_context) + def resolve_integer_details( self, flag_key: str, @@ -92,6 +108,14 @@ def resolve_integer_details( ) -> FlagResolutionDetails[int]: return self._resolve(flag_key, evaluation_context) + async def resolve_integer_details_async( + self, + flag_key: str, + default_value: int, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[int]: + return await self._resolve_async(flag_key, evaluation_context) + def resolve_float_details( self, flag_key: str, @@ -100,6 +124,14 @@ def resolve_float_details( ) -> FlagResolutionDetails[float]: return self._resolve(flag_key, evaluation_context) + async def resolve_float_details_async( + self, + flag_key: str, + default_value: float, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[float]: + return await self._resolve_async(flag_key, evaluation_context) + def resolve_object_details( self, flag_key: str, @@ -108,6 +140,14 @@ def resolve_object_details( ) -> FlagResolutionDetails[typing.Union[dict, list]]: return self._resolve(flag_key, evaluation_context) + async def resolve_object_details_async( + self, + flag_key: str, + default_value: typing.Union[dict, list], + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[typing.Union[dict, list]]: + return await self._resolve_async(flag_key, evaluation_context) + def _resolve( self, flag_key: str, @@ -117,3 +157,10 @@ def _resolve( if flag is None: raise FlagNotFoundError(f"Flag '{flag_key}' not found") return flag.resolve(evaluation_context) + + async def _resolve_async( + self, + flag_key: str, + evaluation_context: typing.Optional[EvaluationContext], + ) -> FlagResolutionDetails[V]: + return self._resolve(flag_key, evaluation_context) diff --git a/tests/provider/test_in_memory_provider.py b/tests/provider/test_in_memory_provider.py index 66d5239e..cdcea7bf 100644 --- a/tests/provider/test_in_memory_provider.py +++ b/tests/provider/test_in_memory_provider.py @@ -17,16 +17,20 @@ def test_should_return_in_memory_provider_metadata(): assert metadata.name == "In-Memory Provider" -def test_should_handle_unknown_flags_correctly(): +@pytest.mark.asyncio +async def test_should_handle_unknown_flags_correctly(): # Given provider = InMemoryProvider({}) # When with pytest.raises(FlagNotFoundError): provider.resolve_boolean_details(flag_key="Key", default_value=True) + with pytest.raises(FlagNotFoundError): + await provider.resolve_integer_details_async(flag_key="Key", default_value=1) # Then -def test_calls_context_evaluator_if_present(): +@pytest.mark.asyncio +async def test_calls_context_evaluator_if_present(): # Given def context_evaluator(flag: InMemoryFlag, evaluation_context: dict): return FlagResolutionDetails( @@ -44,57 +48,81 @@ def context_evaluator(flag: InMemoryFlag, evaluation_context: dict): } ) # When - flag = provider.resolve_boolean_details(flag_key="Key", default_value=False) + flag_sync = provider.resolve_boolean_details(flag_key="Key", default_value=False) + flag_async = await provider.resolve_boolean_details_async( + flag_key="Key", default_value=False + ) # Then - assert flag is not None - assert flag.value is False - assert isinstance(flag.value, bool) - assert flag.reason == Reason.TARGETING_MATCH + assert flag_sync == flag_async + for flag in [flag_sync, flag_async]: + assert flag is not None + assert flag.value is False + assert isinstance(flag.value, bool) + assert flag.reason == Reason.TARGETING_MATCH -def test_should_resolve_boolean_flag_from_in_memory(): +@pytest.mark.asyncio +async def test_should_resolve_boolean_flag_from_in_memory(): # Given provider = InMemoryProvider( {"Key": InMemoryFlag("true", {"true": True, "false": False})} ) # When - flag = provider.resolve_boolean_details(flag_key="Key", default_value=False) + flag_sync = provider.resolve_boolean_details(flag_key="Key", default_value=False) + flag_async = await provider.resolve_boolean_details_async( + flag_key="Key", default_value=False + ) # Then - assert flag is not None - assert flag.value is True - assert isinstance(flag.value, bool) - assert flag.variant == "true" + assert flag_sync == flag_async + for flag in [flag_sync, flag_async]: + assert flag is not None + assert flag.value is True + assert isinstance(flag.value, bool) + assert flag.variant == "true" -def test_should_resolve_integer_flag_from_in_memory(): +@pytest.mark.asyncio +async def test_should_resolve_integer_flag_from_in_memory(): # Given provider = InMemoryProvider( {"Key": InMemoryFlag("hundred", {"zero": 0, "hundred": 100})} ) # When - flag = provider.resolve_integer_details(flag_key="Key", default_value=0) + flag_sync = provider.resolve_integer_details(flag_key="Key", default_value=0) + flag_async = await provider.resolve_integer_details_async( + flag_key="Key", default_value=0 + ) # Then - assert flag is not None - assert flag.value == 100 - assert isinstance(flag.value, Number) - assert flag.variant == "hundred" + assert flag_sync == flag_async + for flag in [flag_sync, flag_async]: + assert flag is not None + assert flag.value == 100 + assert isinstance(flag.value, Number) + assert flag.variant == "hundred" -def test_should_resolve_float_flag_from_in_memory(): +@pytest.mark.asyncio +async def test_should_resolve_float_flag_from_in_memory(): # Given provider = InMemoryProvider( {"Key": InMemoryFlag("ten", {"zero": 0.0, "ten": 10.23})} ) # When - flag = provider.resolve_float_details(flag_key="Key", default_value=0.0) + flag_sync = provider.resolve_float_details(flag_key="Key", default_value=0.0) + flag_async = await provider.resolve_float_details_async( + flag_key="Key", default_value=0.0 + ) # Then - assert flag is not None - assert flag.value == 10.23 - assert isinstance(flag.value, Number) - assert flag.variant == "ten" + assert flag_sync == flag_async + for flag in [flag_sync, flag_async]: + assert flag is not None + assert flag.value == 10.23 + assert isinstance(flag.value, Number) + assert flag.variant == "ten" -def test_should_resolve_string_flag_from_in_memory(): +@pytest.mark.asyncio +async def test_should_resolve_string_flag_from_in_memory(): # Given provider = InMemoryProvider( { @@ -105,29 +133,41 @@ def test_should_resolve_string_flag_from_in_memory(): } ) # When - flag = provider.resolve_string_details(flag_key="Key", default_value="Default") + flag_sync = provider.resolve_string_details(flag_key="Key", default_value="Default") + flag_async = await provider.resolve_string_details_async( + flag_key="Key", default_value="Default" + ) # Then - assert flag is not None - assert flag.value == "String" - assert isinstance(flag.value, str) - assert flag.variant == "stringVariant" + assert flag_sync == flag_async + for flag in [flag_sync, flag_async]: + assert flag is not None + assert flag.value == "String" + assert isinstance(flag.value, str) + assert flag.variant == "stringVariant" -def test_should_resolve_list_flag_from_in_memory(): +@pytest.mark.asyncio +async def test_should_resolve_list_flag_from_in_memory(): # Given provider = InMemoryProvider( {"Key": InMemoryFlag("twoItems", {"empty": [], "twoItems": ["item1", "item2"]})} ) # When - flag = provider.resolve_object_details(flag_key="Key", default_value=[]) + flag_sync = provider.resolve_object_details(flag_key="Key", default_value=[]) + flag_async = await provider.resolve_object_details_async( + flag_key="Key", default_value=[] + ) # Then - assert flag is not None - assert flag.value == ["item1", "item2"] - assert isinstance(flag.value, list) - assert flag.variant == "twoItems" + assert flag_sync == flag_async + for flag in [flag_sync, flag_async]: + assert flag is not None + assert flag.value == ["item1", "item2"] + assert isinstance(flag.value, list) + assert flag.variant == "twoItems" -def test_should_resolve_object_flag_from_in_memory(): +@pytest.mark.asyncio +async def test_should_resolve_object_flag_from_in_memory(): # Given return_value = { "String": "string", @@ -138,9 +178,12 @@ def test_should_resolve_object_flag_from_in_memory(): {"Key": InMemoryFlag("obj", {"obj": return_value, "empty": {}})} ) # When - flag = provider.resolve_object_details(flag_key="Key", default_value={}) + flag_sync = provider.resolve_object_details(flag_key="Key", default_value={}) + flag_async = provider.resolve_object_details(flag_key="Key", default_value={}) # Then - assert flag is not None - assert flag.value == return_value - assert isinstance(flag.value, dict) - assert flag.variant == "obj" + assert flag_sync == flag_async + for flag in [flag_sync, flag_async]: + assert flag is not None + assert flag.value == return_value + assert isinstance(flag.value, dict) + assert flag.variant == "obj" diff --git a/tests/provider/test_provider_compatibility.py b/tests/provider/test_provider_compatibility.py new file mode 100644 index 00000000..aad87db4 --- /dev/null +++ b/tests/provider/test_provider_compatibility.py @@ -0,0 +1,197 @@ +from typing import Optional, Union + +import pytest + +from openfeature.api import get_client, set_provider +from openfeature.evaluation_context import EvaluationContext +from openfeature.flag_evaluation import FlagResolutionDetails +from openfeature.provider import AbstractProvider, Metadata + + +class SynchronousProvider(AbstractProvider): + def get_metadata(self): + return Metadata(name="SynchronousProvider") + + def get_provider_hooks(self): + return [] + + def resolve_boolean_details( + self, + flag_key: str, + default_value: bool, + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[bool]: + return FlagResolutionDetails(value=True) + + def resolve_string_details( + self, + flag_key: str, + default_value: str, + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[str]: + return FlagResolutionDetails(value="string") + + def resolve_integer_details( + self, + flag_key: str, + default_value: int, + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[int]: + return FlagResolutionDetails(value=1) + + def resolve_float_details( + self, + flag_key: str, + default_value: float, + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[float]: + return FlagResolutionDetails(value=10.0) + + def resolve_object_details( + self, + flag_key: str, + default_value: Union[dict, list], + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[Union[dict, list]]: + return FlagResolutionDetails(value={"key": "value"}) + + +@pytest.mark.parametrize( + "flag_type, default_value, get_method", + ( + (bool, True, "get_boolean_value_async"), + (str, "string", "get_string_value_async"), + (int, 1, "get_integer_value_async"), + (float, 10.0, "get_float_value_async"), + ( + dict, + {"key": "value"}, + "get_object_value_async", + ), + ), +) +@pytest.mark.asyncio +async def test_sync_provider_can_be_called_async(flag_type, default_value, get_method): + # Given + set_provider(SynchronousProvider(), "SynchronousProvider") + client = get_client("SynchronousProvider") + # When + async_callable = getattr(client, get_method) + flag = await async_callable(flag_key="Key", default_value=default_value) + # Then + assert flag is not None + assert flag == default_value + assert isinstance(flag, flag_type) + + +@pytest.mark.asyncio +async def test_sync_provider_can_be_extended_async(): + # Given + class ExtendedAsyncProvider(SynchronousProvider): + async def resolve_boolean_details_async( + self, + flag_key: str, + default_value: bool, + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[bool]: + return FlagResolutionDetails(value=False) + + set_provider(ExtendedAsyncProvider(), "ExtendedAsyncProvider") + client = get_client("ExtendedAsyncProvider") + # When + flag = await client.get_boolean_value_async(flag_key="Key", default_value=True) + # Then + assert flag is not None + assert flag is False + + +# We're not allowing providers to only have async methods +def test_sync_methods_enforced_for_async_providers(): + # Given + class AsyncProvider(AbstractProvider): + def get_metadata(self): + return Metadata(name="AsyncProvider") + + async def resolve_boolean_details_async( + self, + flag_key: str, + default_value: bool, + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[bool]: + return FlagResolutionDetails(value=True) + + # When + with pytest.raises(TypeError) as exception: + set_provider(AsyncProvider(), "AsyncProvider") + + # Then + # assert + exception_message = str(exception.value) + assert exception_message.startswith( + "Can't instantiate abstract class AsyncProvider" + ) + assert exception_message.__contains__("resolve_boolean_details") + + +@pytest.mark.asyncio +async def test_async_provider_not_implemented_exception_workaround(): + # Given + class SyncNotImplementedProvider(AbstractProvider): + def get_metadata(self): + return Metadata(name="AsyncProvider") + + async def resolve_boolean_details_async( + self, + flag_key: str, + default_value: bool, + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[bool]: + return FlagResolutionDetails(value=True) + + def resolve_boolean_details( + self, + flag_key: str, + default_value: bool, + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[bool]: + raise NotImplementedError("Use the async method") + + def resolve_string_details( + self, + flag_key: str, + default_value: str, + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[str]: + raise NotImplementedError("Use the async method") + + def resolve_integer_details( + self, + flag_key: str, + default_value: int, + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[int]: + raise NotImplementedError("Use the async method") + + def resolve_float_details( + self, + flag_key: str, + default_value: float, + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[float]: + raise NotImplementedError("Use the async method") + + def resolve_object_details( + self, + flag_key: str, + default_value: Union[dict, list], + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[Union[dict, list]]: + raise NotImplementedError("Use the async method") + + # When + set_provider(SyncNotImplementedProvider(), "SyncNotImplementedProvider") + client = get_client("SyncNotImplementedProvider") + flag = await client.get_boolean_value_async(flag_key="Key", default_value=False) + # Then + assert flag is not None + assert flag is True diff --git a/tests/test_client.py b/tests/test_client.py index f6002c18..5d333993 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,3 +1,4 @@ +import asyncio import time import uuid from concurrent.futures import ThreadPoolExecutor @@ -7,7 +8,7 @@ from openfeature import api from openfeature.api import add_hooks, clear_hooks, get_client, set_provider -from openfeature.client import OpenFeatureClient +from openfeature.client import GeneralError, OpenFeatureClient, _typecheck_flag_value from openfeature.evaluation_context import EvaluationContext from openfeature.event import EventDetails, ProviderEvent, ProviderEventDetails from openfeature.exception import ErrorCode, OpenFeatureError @@ -23,9 +24,13 @@ "flag_type, default_value, get_method", ( (bool, True, "get_boolean_value"), + (bool, True, "get_boolean_value_async"), (str, "String", "get_string_value"), + (str, "String", "get_string_value_async"), (int, 100, "get_integer_value"), + (int, 100, "get_integer_value_async"), (float, 10.23, "get_float_value"), + (float, 10.23, "get_float_value_async"), ( dict, { @@ -35,21 +40,38 @@ }, "get_object_value", ), + ( + dict, + { + "String": "string", + "Number": 2, + "Boolean": True, + }, + "get_object_value_async", + ), ( list, ["string1", "string2"], "get_object_value", ), + ( + list, + ["string1", "string2"], + "get_object_value_async", + ), ), ) -def test_should_get_flag_value_based_on_method_type( +@pytest.mark.asyncio +async def test_should_get_flag_value_based_on_method_type( flag_type, default_value, get_method, no_op_provider_client ): # Given # When - flag = getattr(no_op_provider_client, get_method)( - flag_key="Key", default_value=default_value - ) + method = getattr(no_op_provider_client, get_method) + if asyncio.iscoroutinefunction(method): + flag = await method(flag_key="Key", default_value=default_value) + else: + flag = method(flag_key="Key", default_value=default_value) # Then assert flag is not None assert flag == default_value @@ -60,9 +82,13 @@ def test_should_get_flag_value_based_on_method_type( "flag_type, default_value, get_method", ( (bool, True, "get_boolean_details"), + (bool, True, "get_boolean_details_async"), (str, "String", "get_string_details"), + (str, "String", "get_string_details_async"), (int, 100, "get_integer_details"), + (int, 100, "get_integer_details_async"), (float, 10.23, "get_float_details"), + (float, 10.23, "get_float_details_async"), ( dict, { @@ -72,38 +98,62 @@ def test_should_get_flag_value_based_on_method_type( }, "get_object_details", ), + ( + dict, + { + "String": "string", + "Number": 2, + "Boolean": True, + }, + "get_object_details_async", + ), ( list, ["string1", "string2"], "get_object_details", ), + ( + list, + ["string1", "string2"], + "get_object_details_async", + ), ), ) -def test_should_get_flag_detail_based_on_method_type( +@pytest.mark.asyncio +async def test_should_get_flag_detail_based_on_method_type( flag_type, default_value, get_method, no_op_provider_client ): # Given # When - flag = getattr(no_op_provider_client, get_method)( - flag_key="Key", default_value=default_value - ) + method = getattr(no_op_provider_client, get_method) + if asyncio.iscoroutinefunction(method): + flag = await method(flag_key="Key", default_value=default_value) + else: + flag = method(flag_key="Key", default_value=default_value) # Then assert flag is not None assert flag.value == default_value assert isinstance(flag.value, flag_type) -def test_should_raise_exception_when_invalid_flag_type_provided(no_op_provider_client): +@pytest.mark.asyncio +async def test_should_raise_exception_when_invalid_flag_type_provided( + no_op_provider_client, +): # Given # When - flag = no_op_provider_client.evaluate_flag_details( + flag_sync = no_op_provider_client.evaluate_flag_details( + flag_type=None, flag_key="Key", default_value=True + ) + flag_async = await no_op_provider_client.evaluate_flag_details_async( flag_type=None, flag_key="Key", default_value=True ) # Then - assert flag.value - assert flag.error_message == "Unknown flag type" - assert flag.error_code == ErrorCode.GENERAL - assert flag.reason == Reason.ERROR + for flag in [flag_sync, flag_async]: + assert flag.value + assert flag.error_message == "Unknown flag type" + assert flag.error_code == ErrorCode.GENERAL + assert flag.reason == Reason.ERROR def test_should_pass_flag_metadata_from_resolution_to_evaluation_details(): @@ -202,7 +252,8 @@ def test_should_define_a_provider_status_accessor(no_op_provider_client): # Requirement 1.7.6 -def test_should_shortcircuit_if_provider_is_not_ready( +@pytest.mark.asyncio +async def test_should_shortcircuit_if_provider_is_not_ready( no_op_provider_client, monkeypatch ): # Given @@ -212,20 +263,27 @@ def test_should_shortcircuit_if_provider_is_not_ready( spy_hook = MagicMock(spec=Hook) no_op_provider_client.add_hooks([spy_hook]) # When - flag_details = no_op_provider_client.get_boolean_details( + flag_details_sync = no_op_provider_client.get_boolean_details( + flag_key="Key", default_value=True + ) + spy_hook.error.assert_called_once() + spy_hook.reset_mock() + flag_details_async = await no_op_provider_client.get_boolean_details_async( flag_key="Key", default_value=True ) # Then - assert flag_details is not None - assert flag_details.value - assert flag_details.reason == Reason.ERROR - assert flag_details.error_code == ErrorCode.PROVIDER_NOT_READY + for flag_details in [flag_details_sync, flag_details_async]: + assert flag_details is not None + assert flag_details.value + assert flag_details.reason == Reason.ERROR + assert flag_details.error_code == ErrorCode.PROVIDER_NOT_READY spy_hook.error.assert_called_once() spy_hook.finally_after.assert_called_once() # Requirement 1.7.7 -def test_should_shortcircuit_if_provider_is_in_irrecoverable_error_state( +@pytest.mark.asyncio +async def test_should_shortcircuit_if_provider_is_in_irrecoverable_error_state( no_op_provider_client, monkeypatch ): # Given @@ -235,42 +293,88 @@ def test_should_shortcircuit_if_provider_is_in_irrecoverable_error_state( spy_hook = MagicMock(spec=Hook) no_op_provider_client.add_hooks([spy_hook]) # When - flag_details = no_op_provider_client.get_boolean_details( + flag_details_sync = no_op_provider_client.get_boolean_details( + flag_key="Key", default_value=True + ) + spy_hook.error.assert_called_once() + spy_hook.reset_mock() + flag_details_async = await no_op_provider_client.get_boolean_details_async( flag_key="Key", default_value=True ) # Then - assert flag_details is not None - assert flag_details.value - assert flag_details.reason == Reason.ERROR - assert flag_details.error_code == ErrorCode.PROVIDER_FATAL + for flag_details in [flag_details_sync, flag_details_async]: + assert flag_details is not None + assert flag_details.value + assert flag_details.reason == Reason.ERROR + assert flag_details.error_code == ErrorCode.PROVIDER_FATAL spy_hook.error.assert_called_once() spy_hook.finally_after.assert_called_once() -def test_should_run_error_hooks_if_provider_returns_resolution_with_error_code(): +@pytest.mark.asyncio +async def test_should_run_error_hooks_if_provider_returns_resolution_with_error_code(): # Given spy_hook = MagicMock(spec=Hook) provider = MagicMock(spec=FeatureProvider) provider.get_provider_hooks.return_value = [] - provider.resolve_boolean_details.return_value = FlagResolutionDetails( + mock_resolution = FlagResolutionDetails( value=True, reason=Reason.ERROR, error_code=ErrorCode.PROVIDER_FATAL, error_message="This is an error message", ) + provider.resolve_boolean_details.return_value = mock_resolution + provider.resolve_boolean_details_async.return_value = mock_resolution set_provider(provider) client = get_client() client.add_hooks([spy_hook]) # When - flag_details = client.get_boolean_details(flag_key="Key", default_value=True) + flag_details_sync = client.get_boolean_details(flag_key="Key", default_value=True) + spy_hook.error.assert_called_once() + spy_hook.reset_mock() + flag_details_async = await client.get_boolean_details_async( + flag_key="Key", default_value=True + ) # Then - assert flag_details is not None - assert flag_details.value - assert flag_details.reason == Reason.ERROR - assert flag_details.error_code == ErrorCode.PROVIDER_FATAL + for flag_details in [flag_details_sync, flag_details_async]: + assert flag_details is not None + assert flag_details.value + assert flag_details.reason == Reason.ERROR + assert flag_details.error_code == ErrorCode.PROVIDER_FATAL spy_hook.error.assert_called_once() +@pytest.mark.asyncio +async def test_client_type_mismatch_exceptions(): + # Given + client = get_client() + # When + flag_details_sync = client.get_boolean_details( + flag_key="Key", default_value="type mismatch" + ) + flag_details_async = await client.get_boolean_details_async( + flag_key="Key", default_value="type mismatch" + ) + # Then + for flag_details in [flag_details_sync, flag_details_async]: + assert flag_details is not None + assert flag_details.value + assert flag_details.reason == Reason.ERROR + assert flag_details.error_code == ErrorCode.TYPE_MISMATCH + + +@pytest.mark.asyncio +async def test_client_general_exception(): + # Given + flag_value = "A" + flag_type = None + # When + with pytest.raises(GeneralError) as e: + flag_type = _typecheck_flag_value(flag_value, flag_type) + # Then + assert e.value.error_message == "Unknown flag type" + + def test_provider_events(): # Given provider = NoOpProvider()