Skip to content

Commit

Permalink
Merge pull request #200 from permitio/omer/per-10806-get-user-tenants
Browse files Browse the repository at this point in the history
Add get user tenants external data store
  • Loading branch information
omer9564 authored Oct 27, 2024
2 parents 43f299b + 45e8bd5 commit e5c0128
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 41 deletions.
56 changes: 41 additions & 15 deletions horizon/enforcer/api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import json
import re
from typing import cast, Optional, Union, Dict, List
from typing import cast, Optional, Union, Dict, List, Callable

import aiohttp
from fastapi import APIRouter, Depends, Header
Expand Down Expand Up @@ -276,32 +276,50 @@ async def conditional_is_allowed(
*,
policy_package: str = MAIN_POLICY_PACKAGE,
external_data_manager_path: str = "/check",
external_data_manager_method: str = "POST",
external_data_manager_params: dict | None = None,
legacy_parse_func: Callable[[dict | list], dict] | None = None,
) -> dict:
if sidecar_config.ENABLE_EXTERNAL_DATA_MANAGER:
response = await _is_allowed_data_manager(query, request, path=external_data_manager_path)
response = await _is_allowed_data_manager(
query if external_data_manager_method != "GET" else None,
request,
path=external_data_manager_path,
method=external_data_manager_method,
params=external_data_manager_params,
)
raw_result = json.loads(response.body)
log_query_result(query, response, is_inner=True)
else:
response = await _is_allowed(query, request, policy_package)
raw_result = json.loads(response.body).get("result", {})
log_query_result(query, response)
if legacy_parse_func:
raw_result = legacy_parse_func(raw_result)
return raw_result


async def _is_allowed_data_manager(
query: BaseSchema, request: Request, *, path: str = "/check"
query: BaseSchema | None,
request: Request,
*,
path: str = "/check",
method: str = "POST",
params: dict | None = None,
):
headers = transform_headers(request)
url = f"{sidecar_config.DATA_MANAGER_SERVICE_URL}/v1/authz{path}"
payload = {"input": query.dict()}
payload = None if query is None else {"input": query.dict()}
exc = None
_set_use_debugger(payload)
try:
logger.info(f"calling Data Manager at '{url}' with input: {payload}")
async with aiohttp.ClientSession() as session:
async with session.post(
async with session.request(
method,
url,
data=json.dumps(payload["input"]) if payload is not None else None,
params=params,
headers=headers,
timeout=sidecar_config.OPA_CLIENT_QUERY_TIMEOUT,
raise_for_status=True,
Expand Down Expand Up @@ -490,19 +508,27 @@ async def user_tenants(
query: UserTenantsQuery,
x_permit_sdk_language: Optional[str] = Depends(notify_seen_sdk),
):
response = await _is_allowed(query, request, USER_TENANTS_POLICY_PACKAGE)
log_query_result(query, response)
try:
raw_result = json.loads(response.body).get("result", {})
def parse_func(result: dict | list) -> dict | list:
if isinstance(raw_result, dict):
tenants = raw_result.get("tenants", {})
tenants = result.get("tenants", [])
elif isinstance(raw_result, list):
tenants = raw_result
tenants = result
else:
raise TypeError(
f"Expected raw result to be dict or list, got {type(raw_result)}"
)
result = parse_obj_as(UserTenantsResult, tenants)
return tenants

raw_result = await conditional_is_allowed(
query,
request,
policy_package=USER_TENANTS_POLICY_PACKAGE,
external_data_manager_path=f"/users/{query.user.key}/tenants",
external_data_manager_method="GET",
legacy_parse_func=parse_func,
)
try:
result = parse_obj_as(UserTenantsResult, raw_result)
except:
result = parse_obj_as(UserTenantsResult, [])
logger.warning(
Expand Down Expand Up @@ -607,8 +633,8 @@ async def is_allowed(
raise HTTPException(
status_code=status.HTTP_421_MISDIRECTED_REQUEST,
detail="Mismatch between client version and PDP version,"
" required v2 request body, got v1. "
"hint: try to update your client version to v2",
" required v2 request body, got v1. "
"hint: try to update your client version to v2",
)
query = cast(AuthorizationQuery, query)

Expand Down Expand Up @@ -689,7 +715,7 @@ async def is_allowed_kong(request: Request, query: KongAuthorizationQuery):
raise HTTPException(
status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Kong integration is disabled. "
"Please set the PDP_KONG_INTEGRATION variable to true to enable it.",
"Please set the PDP_KONG_INTEGRATION variable to true to enable it.",
)

await PersistentStateHandler.get_instance().seen_sdk("kong")
Expand Down
79 changes: 53 additions & 26 deletions horizon/tests/test_enforcer_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,16 @@ async def pdp_api_client() -> TestClient:
{"allow": [{"allow": True, "result": True}]},
{"allow": [{"allow": True, "result": True}]},
),
(
"/user-tenants",
"/users/user1/tenants",
UserTenantsQuery(
user=User(key="user1"),
),
None,
[{"key": "default-2", "attributes": {}}, {"key": "default", "attributes": {}}],
[{"key": "default-2", "attributes": {}}, {"key": "default", "attributes": {}}],
),
]


Expand Down Expand Up @@ -429,7 +439,14 @@ def post_endpoint():


@pytest.mark.parametrize(
("endpoint", "datasync_endpoint", "query", "headers", "datasync_response", "expected_response"),
(
"endpoint",
"datasync_endpoint",
"query",
"headers",
"datasync_response",
"expected_response",
),
ALLOWED_ENDPOINTS_DATASYNC,
)
def test_enforce_endpoint_datasync(
Expand All @@ -446,7 +463,8 @@ def test_enforce_endpoint_datasync(
def post_endpoint():
return _client.post(
endpoint,
headers={"authorization": f"Bearer {sidecar_config.API_KEY}"} | (headers or {}),
headers={"authorization": f"Bearer {sidecar_config.API_KEY}"}
| (headers or {}),
json=jsonable_encoder(query) if query else None,
)

Expand All @@ -455,33 +473,39 @@ def post_endpoint():
f"{sidecar_config.DATA_MANAGER_SERVICE_URL}/v1/authz{datasync_endpoint}"
)

method = "POST"

match endpoint:
case "/allowed_url":
# allowed_url gonna first call the mapping rules endpoint then the normal OPA allow endpoint
m.post(
url=f"{opal_client_config.POLICY_STORE_URL}/v1/data/mapping_rules",
status=200,
payload={
"result": {
"all": [
{
"url": "https://some.url/important_resource",
"http_method": "delete",
"action": "delete",
"resource": "resource1",
}
]
}
},
repeat=True,
)
case "/user-tenants":
method = "GET"

# Test valid response from OPA
m.post(
m.add(
datasync_url,
method=method,
status=200,
payload=datasync_response,
)

if endpoint == "/allowed_url":
# allowed_url gonna first call the mapping rules endpoint then the normal OPA allow endpoint
m.post(
url=f"{opal_client_config.POLICY_STORE_URL}/v1/data/mapping_rules",
status=200,
payload={
"result": {
"all": [
{
"url": "https://some.url/important_resource",
"http_method": "delete",
"action": "delete",
"resource": "resource1",
}
]
}
},
repeat=True,
)

response = post_endpoint()
assert response.status_code == 200
print(response.json())
Expand All @@ -499,8 +523,9 @@ def post_endpoint():

# Test bad status from OPA
bad_status = random.choice([401, 404, 400, 500, 503])
m.post(
m.add(
datasync_url,
method=method,
status=bad_status,
payload=datasync_response,
)
Expand All @@ -510,8 +535,9 @@ def post_endpoint():
assert f"status: {bad_status}" in response.text

# Test connection error
m.post(
m.add(
datasync_url,
method=method,
exception=aiohttp.ClientConnectionError("don't want to connect"),
)
response = post_endpoint()
Expand All @@ -520,8 +546,9 @@ def post_endpoint():
assert "don't want to connect" in response.text

# Test timeout - not working yet
m.post(
m.add(
datasync_url,
method=method,
exception=asyncio.exceptions.TimeoutError(),
)
response = post_endpoint()
Expand Down

0 comments on commit e5c0128

Please sign in to comment.