Skip to content

Commit

Permalink
feat(openfgaclient): add support for BatchCheck API
Browse files Browse the repository at this point in the history
  • Loading branch information
ewanharris committed Dec 12, 2024
1 parent 885c2de commit a59d9a1
Show file tree
Hide file tree
Showing 11 changed files with 1,283 additions and 70 deletions.
35 changes: 22 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -724,17 +724,19 @@ If 429s or 5xxs are encountered, the underlying check will retry up to 15 times

```python
# from openfga_sdk import OpenFgaClient
# from openfga_sdk.client import ClientCheckRequest
# from openfga_sdk.client.models import ClientTuple

# from openfga_sdk.client.models import (
# ClientTuple,
# ClientBatchCheckItem,
# ClientBatchCheckRequest,
# )
# Initialize the fga_client
# fga_client = OpenFgaClient(configuration)

options = {
# You can rely on the model id set in the configuration or override it for this specific request
"authorization_model_id": "01GXSA8YR785C4FYS3C0RTG7B1"
}
body = [ClientCheckRequest(
checks = [ClientBatchCheckItem(
user="user:81684243-9356-4421-8fbf-a4f8d36aa31b",
relation="viewer",
object="document:0192ab2a-d83f-756d-9397-c5ed9f3cb69a",
Expand All @@ -748,7 +750,7 @@ body = [ClientCheckRequest(
context=dict(
ViewCount=100
)
), ClientCheckRequest(
), ClientBatchCheckItem(
user="user:81684243-9356-4421-8fbf-a4f8d36aa31b",
relation="admin",
object="document:0192ab2a-d83f-756d-9397-c5ed9f3cb69a",
Expand All @@ -759,20 +761,21 @@ body = [ClientCheckRequest(
object="document:0192ab2a-d83f-756d-9397-c5ed9f3cb69a",
),
]
), ClientCheckRequest(
), ClientBatchCheckItem(
user="user:81684243-9356-4421-8fbf-a4f8d36aa31b",
relation="creator",
object="document:0192ab2a-d83f-756d-9397-c5ed9f3cb69a",
), ClientCheckRequest(
), ClientBatchCheckItem(
user="user:81684243-9356-4421-8fbf-a4f8d36aa31b",
relation="deleter",
object="document:0192ab2a-d83f-756d-9397-c5ed9f3cb69a",
)]

response = await fga_client.batch_check(body, options)
# response.responses = [{
response = await fga_client.batch_check(ClientBatchCheckRequest(checks=checks), options)
# response.result = [{
# allowed: false,
# request: {
# correlation_id: "de3630c2-f9be-4ee5-9441-cb1fbd82ce75",
# tuple: {
# user: "user:81684243-9356-4421-8fbf-a4f8d36aa31b",
# relation: "viewer",
# object: "document:0192ab2a-d83f-756d-9397-c5ed9f3cb69a",
Expand All @@ -787,7 +790,8 @@ response = await fga_client.batch_check(body, options)
# }
# }, {
# allowed: false,
# request: {
# correlation_id: "6d7c7129-9607-480e-bfd0-17c16e46b9ec",
# tuple: {
# user: "user:81684243-9356-4421-8fbf-a4f8d36aa31b",
# relation: "admin",
# object: "document:0192ab2a-d83f-756d-9397-c5ed9f3cb69a",
Expand All @@ -799,14 +803,19 @@ response = await fga_client.batch_check(body, options)
# }
# }, {
# allowed: false,
# request: {
# correlation_id: "210899b9-6bc3-4491-bdd1-d3d79780aa31",
# tuple: {
# user: "user:81684243-9356-4421-8fbf-a4f8d36aa31b",
# relation: "creator",
# object: "document:0192ab2a-d83f-756d-9397-c5ed9f3cb69a",
# },
# error: <FgaError ...>
# error: {
# input_error: "validation_error",
# message: "relation 'document#creator' not found"
# }
# }, {
# allowed: true,
# correlation_id: "55cc1946-9fc3-4710-bd40-8fe2687ed8da",
# request: {
# user: "user:81684243-9356-4421-8fbf-a4f8d36aa31b",
# relation: "deleter",
Expand Down
29 changes: 29 additions & 0 deletions example/example1/example1.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import os
import uuid

from openfga_sdk import (
ClientConfiguration,
Expand All @@ -21,6 +22,8 @@
)
from openfga_sdk.client.models import (
ClientAssertion,
ClientBatchCheckItem,
ClientBatchCheckRequest,
ClientCheckRequest,
ClientListObjectsRequest,
ClientListRelationsRequest,
Expand Down Expand Up @@ -268,6 +271,32 @@ async def main():
)
print(f"Allowed: {response.allowed}")

# Performing a BatchCheck
print("Checking for access via BatchCheck")

anne_cor_id = str(uuid.uuid4())
response = await fga_client.batch_check(
ClientBatchCheckRequest(
checks=[
ClientBatchCheckItem(
user="user:anne",
relation="viewer",
object="document:0192ab2a-d83f-756d-9397-c5ed9f3cb69a",
context=dict(ViewCount=100),
correlation_id=anne_cor_id,
),
ClientBatchCheckItem(
user="user:bob",
relation="viewer",
object="document:0192ab2a-d83f-756d-9397-c5ed9f3cb69a",
context=dict(ViewCount=100),
),
]
)
)
print(f"Anne allowed: {response.result[0].allowed}")
print(f"Bob allowed: {response.result[1].allowed}")

# List objects with context
print("Listing objects for access with context")

Expand Down
129 changes: 127 additions & 2 deletions openfga_sdk/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,24 @@
from openfga_sdk.api_client import ApiClient
from openfga_sdk.client.configuration import ClientConfiguration
from openfga_sdk.client.models.assertion import ClientAssertion
from openfga_sdk.client.models.batch_check_response import BatchCheckResponse
from openfga_sdk.client.models.batch_check_item import (
ClientBatchCheckItem,
construct_batch_item,
)
from openfga_sdk.client.models.batch_check_request import ClientBatchCheckRequest
from openfga_sdk.client.models.batch_check_response import (
ClientBatchCheckResponse,
)
from openfga_sdk.client.models.batch_check_single_response import (
ClientBatchCheckSingleResponse,
)
from openfga_sdk.client.models.check_request import (
ClientCheckRequest,
construct_check_request,
)
from openfga_sdk.client.models.client_batch_check_response import ClientBatchCheckClientResponse
from openfga_sdk.client.models.client_batch_check_response import (
ClientBatchCheckClientResponse,
)
from openfga_sdk.client.models.expand_request import ClientExpandRequest
from openfga_sdk.client.models.list_objects_request import ClientListObjectsRequest
from openfga_sdk.client.models.list_relations_request import ClientListRelationsRequest
Expand All @@ -41,6 +53,7 @@
UnauthorizedException,
)
from openfga_sdk.models.assertion import Assertion
from openfga_sdk.models.batch_check_request import BatchCheckRequest
from openfga_sdk.models.check_request import CheckRequest
from openfga_sdk.models.contextual_tuple_keys import ContextualTupleKeys
from openfga_sdk.models.create_store_request import CreateStoreRequest
Expand Down Expand Up @@ -636,6 +649,118 @@ async def client_batch_check(

return batch_check_response

async def _single_batch_check(
self,
body: BatchCheckRequest,
semaphore: asyncio.Semaphore,
options: dict[str, str] = None,
):
"""
Run a single BatchCheck request
:param body - list[ClientCheckRequest] defining check request
:param authorization_model_id(options) - Overrides the authorization model id in the configuration
"""
await semaphore.acquire()
try:
kwargs = options_to_kwargs(options)
api_response = await self._api.batch_check(body, **kwargs)
return api_response
except Exception as err:
raise err
finally:
semaphore.release()

async def batch_check(self, body: ClientBatchCheckRequest, options=None):
"""
Run a batchcheck request
:param body - BatchCheck request
:param authorization_model_id(options) - Overrides the authorization model id in the configuration
:param max_parallel_requests(options) - Max number of requests to issue in parallel. Defaults to 10
:param max_batch_size(options) - Max number of checks to include in a request. Defaults to 50
:param header(options) - Custom headers to send alongside the request
:param retryParams(options) - Override the retry parameters for this request
:param retryParams.maxRetry(options) - Override the max number of retries on each API request
:param retryParams.minWaitInMs(options) - Override the minimum wait before a retry is initiated
"""
options = set_heading_if_not_set(
options, CLIENT_BULK_REQUEST_ID_HEADER, str(uuid.uuid4())
)

max_parallel_requests = 10
if options is not None and "max_parallel_requests" in options:
if (
isinstance(options["max_parallel_requests"], str)
and options["max_parallel_requests"].isdigit()
):
max_parallel_requests = int(options["max_parallel_requests"])
elif isinstance(options["max_parallel_requests"], int):
max_parallel_requests = options["max_parallel_requests"]

max_batch_size = 50
if options is not None and "max_batch_size" in options:
if (
isinstance(options["max_batch_size"], str)
and options["max_batch_size"].isdigit()
):
max_batch_size = int(options["max_batch_size"])
elif isinstance(options["max_batch_size"], int):
max_batch_size = options["max_batch_size"]

check_to_id: dict[str, ClientBatchCheckItem] = {}

def track_and_transform(checks):
transformed = []
for check in checks:
if check.correlation_id is None:
check.correlation_id = str(uuid.uuid4())

if check.correlation_id in check_to_id:
raise FgaValidationException("Duplicate correlation_id provided")

check_to_id[check.correlation_id] = check

transformed.append(construct_batch_item(check))
return transformed

checks = [
track_and_transform(
body.checks[i * max_batch_size : (i + 1) * max_batch_size]
)
for i in range((len(body.checks) + max_batch_size - 1) // max_batch_size)
]

result = []
sem = asyncio.Semaphore(max_parallel_requests)

def map_response(id, result):
check = check_to_id[id]
return ClientBatchCheckSingleResponse(
allowed=result.allowed,
tuple=check,
correlation_id=id,
error=result.error,
)

async def coro(checks):
res = await self._single_batch_check(
BatchCheckRequest(
checks=checks,
authorization_model_id=self._get_authorization_model_id(options),
consistency=self._get_consistency(options),
),
sem,
options,
)

result.extend(
[map_response(c_id, c_result) for c_id, c_result in res.result.items()]
)

batch_check_coros = [coro(request) for request in checks]
await asyncio.gather(*batch_check_coros)

return ClientBatchCheckResponse(result)

async def expand(self, body: ClientExpandRequest, options: dict[str, str] = None):
"""
Run expand request
Expand Down
11 changes: 9 additions & 2 deletions openfga_sdk/client/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,16 @@
"""

from openfga_sdk.client.models.assertion import ClientAssertion
from openfga_sdk.client.models.batch_check_response import BatchCheckResponse
from openfga_sdk.client.models.batch_check_item import ClientBatchCheckItem
from openfga_sdk.client.models.batch_check_request import ClientBatchCheckRequest
from openfga_sdk.client.models.batch_check_response import ClientBatchCheckResponse
from openfga_sdk.client.models.batch_check_single_response import (
ClientBatchCheckSingleResponse,
)
from openfga_sdk.client.models.check_request import ClientCheckRequest
from openfga_sdk.client.models.client_batch_check_response import ClientBatchCheckClientResponse
from openfga_sdk.client.models.client_batch_check_response import (
ClientBatchCheckClientResponse,
)
from openfga_sdk.client.models.expand_request import ClientExpandRequest
from openfga_sdk.client.models.list_objects_request import ClientListObjectsRequest
from openfga_sdk.client.models.list_relations_request import ClientListRelationsRequest
Expand Down
Loading

0 comments on commit a59d9a1

Please sign in to comment.