Skip to content

Commit

Permalink
add test coverage, async providers calling sync calls, async only client
Browse files Browse the repository at this point in the history
Signed-off-by: leohoare <[email protected]>
  • Loading branch information
leohoare committed Jan 22, 2025
1 parent 3338f29 commit 86c64df
Show file tree
Hide file tree
Showing 3 changed files with 309 additions and 31 deletions.
4 changes: 3 additions & 1 deletion tests/provider/test_in_memory_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,9 @@ async def test_should_resolve_list_flag_from_in_memory():
)
# When
flag_sync = provider.resolve_object_details(flag_key="Key", default_value=[])
flag_async = provider.resolve_object_details(flag_key="Key", default_value=[])
flag_async = await provider.resolve_object_details_async(
flag_key="Key", default_value=[]
)
# Then
assert flag_sync == flag_async
for flag in [flag_sync, flag_async]:
Expand Down
196 changes: 196 additions & 0 deletions tests/provider/test_provider_compatibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
import asyncio
from typing import Optional, Union

import pytest

from openfeature.api import OpenFeatureClient, get_client, set_provider
from openfeature.evaluation_context import EvaluationContext
from openfeature.flag_evaluation import FlagResolutionDetails
from openfeature.provider import AbstractProvider, Metadata


class SynchronousProvider(AbstractProvider):
def get_metadata(self):
return Metadata(name="SynchronousProvider")

def get_provider_hooks(self):
return []

def resolve_boolean_details(
self,
flag_key: str,
default_value: bool,
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[bool]:
return FlagResolutionDetails(value=True)

def resolve_string_details(
self,
flag_key: str,
default_value: str,
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[str]:
return FlagResolutionDetails(value="string")

def resolve_integer_details(
self,
flag_key: str,
default_value: int,
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[int]:
return FlagResolutionDetails(value=1)

def resolve_float_details(
self,
flag_key: str,
default_value: float,
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[float]:
return FlagResolutionDetails(value=10.0)

def resolve_object_details(
self,
flag_key: str,
default_value: Union[dict, list],
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[Union[dict, list]]:
return FlagResolutionDetails(value={"key": "value"})


@pytest.mark.parametrize(
"flag_type, default_value, get_method",
(
(bool, True, "get_boolean_value_async"),
(str, "string", "get_string_value_async"),
(int, 1, "get_integer_value_async"),
(float, 10.0, "get_float_value_async"),
(
dict,
{"key": "value"},
"get_object_value_async",
),
),
)
@pytest.mark.asyncio
async def test_sync_provider_can_be_called_async(flag_type, default_value, get_method):
# Given
set_provider(SynchronousProvider(), "SynchronousProvider")
client = get_client("SynchronousProvider")
# When
async_callable = getattr(client, get_method)
flag = await async_callable(flag_key="Key", default_value=default_value)
# Then
assert flag is not None
assert flag == default_value
assert isinstance(flag, flag_type)


@pytest.mark.asyncio
async def test_sync_provider_can_be_extended_async():
# Given
class ExtendedAsyncProvider(SynchronousProvider):
async def resolve_boolean_details_async(
self,
flag_key: str,
default_value: bool,
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[bool]:
return FlagResolutionDetails(value=False)

set_provider(ExtendedAsyncProvider(), "ExtendedAsyncProvider")
client = get_client("ExtendedAsyncProvider")
# When
flag = await client.get_boolean_value_async(flag_key="Key", default_value=True)
# Then
assert flag is not None
assert flag is False


# We're not allowing providers to only have async methods
def test_sync_methods_enforced_for_async_providers():
# Given
class AsyncProvider(AbstractProvider):
def get_metadata(self):
return Metadata(name="AsyncProvider")

async def resolve_boolean_details_async(
self,
flag_key: str,
default_value: bool,
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[bool]:
return FlagResolutionDetails(value=True)

# When
with pytest.raises(TypeError) as exception:
set_provider(AsyncProvider(), "AsyncProvider")

# Then
# assert
assert str(exception.value).startswith(
"Can't instantiate abstract class AsyncProvider with abstract methods resolve_boolean_details"
)


@pytest.mark.asyncio
async def test_async_provider_not_implemented_exception_workaround():
# Given
class SyncNotImplementedProvider(AbstractProvider):
def get_metadata(self):
return Metadata(name="AsyncProvider")

async def resolve_boolean_details_async(
self,
flag_key: str,
default_value: bool,
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[bool]:
return FlagResolutionDetails(value=True)

def resolve_boolean_details(
self,
flag_key: str,
default_value: bool,
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[bool]:
raise NotImplementedError("Use the async method")

def resolve_string_details(
self,
flag_key: str,
default_value: str,
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[str]:
raise NotImplementedError("Use the async method")

def resolve_integer_details(
self,
flag_key: str,
default_value: int,
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[int]:
raise NotImplementedError("Use the async method")

def resolve_float_details(
self,
flag_key: str,
default_value: float,
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[float]:
raise NotImplementedError("Use the async method")

def resolve_object_details(
self,
flag_key: str,
default_value: Union[dict, list],
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[Union[dict, list]]:
raise NotImplementedError("Use the async method")

# When
set_provider(SyncNotImplementedProvider(), "SyncNotImplementedProvider")
client = get_client("SyncNotImplementedProvider")
flag = await client.get_boolean_value_async(flag_key="Key", default_value=False)
# Then
assert flag is not None
assert flag is True
Loading

0 comments on commit 86c64df

Please sign in to comment.