Skip to content

Commit

Permalink
Enable OAuth2 authentication.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ondrej Scecina committed Nov 19, 2024
1 parent 7a2a713 commit 89e2d99
Show file tree
Hide file tree
Showing 8 changed files with 60 additions and 31 deletions.
4 changes: 3 additions & 1 deletion packages/opal-client/opal_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,9 @@ def _configure_api_routes(self, app: FastAPI):
policy_router = init_policy_router(policy_updater=self.policy_updater)
data_router = init_data_router(data_updater=self.data_updater)
policy_store_router = init_policy_store_router(self.authenticator)
callbacks_router = init_callbacks_api(self.authenticator, self._callbacks_register)
callbacks_router = init_callbacks_api(
self.authenticator, self._callbacks_register
)

# mount the api routes on the app object
app.include_router(policy_router, tags=["Policy Updater"])
Expand Down
27 changes: 19 additions & 8 deletions packages/opal-client/opal_client/data/oauth2_updater.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,46 @@
from urllib.parse import parse_qs, urlencode, urlparse

import aiohttp
from aiohttp.client import ClientSession
from opal_client.logger import logger
from urllib.parse import urlencode, urlparse, parse_qs

from .updater import DefaultDataUpdater


class OAuth2DataUpdater(DefaultDataUpdater):
async def _load_policy_data_config(self, url: str, headers) -> aiohttp.ClientResponse:
async def _load_policy_data_config(
self, url: str, headers
) -> aiohttp.ClientResponse:
await self._authenticator.authenticate(headers)

async with ClientSession(headers=headers) as session:
response = await session.get(url, **self._ssl_context_kwargs, allow_redirects=False)
response = await session.get(
url, **self._ssl_context_kwargs, allow_redirects=False
)

if response.status == 307:
return await self._load_redirected_policy_data_config(response.headers['location'], headers)
return await self._load_redirected_policy_data_config(
response.headers['location'], headers
)
else:
return response

async def _load_redirected_policy_data_config(self, url: str, headers):
redirect_url = self.__redirect_url(url)

logger.info("Redirecting to data-sources configuration '{source}'", source=redirect_url)
logger.info(
"Redirecting to data-sources configuration '{source}'", source=redirect_url
)

async with ClientSession(headers=headers) as session:
return await session.get(redirect_url, **self._ssl_context_kwargs, allow_redirects=False)
return await session.get(
redirect_url, **self._ssl_context_kwargs, allow_redirects=False
)

def __redirect_url(self, url: str) -> str:
u = urlparse(url)
query = parse_qs(u.query, keep_blank_values=True)
query.pop('token', None)
query.pop("token", None)
u = u._replace(query=urlencode(query, True))

return u.geturl()
return u.geturl()
10 changes: 7 additions & 3 deletions packages/opal-client/opal_client/data/updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ async def get_policy_data_config(self, url: str = None) -> DataSourceConfig:
headers = {}
if self._extra_headers is not None:
headers = self._extra_headers.copy()
headers['Accept'] = "application/json"
headers["Accept"] = "application/json"

try:
response = await self._load_policy_data_config(url, headers)
Expand All @@ -257,7 +257,9 @@ async def get_policy_data_config(self, url: str = None) -> DataSourceConfig:
logger.exception(f"Failed to load data sources config")
raise

async def _load_policy_data_config(self, url: str, headers) -> aiohttp.ClientResponse:
async def _load_policy_data_config(
self, url: str, headers
) -> aiohttp.ClientResponse:
async with ClientSession(headers=headers) as session:
return await session.get(url, **self._ssl_context_kwargs)

Expand Down Expand Up @@ -527,7 +529,9 @@ async def _store_fetched_update(self, update_item):
policy_data = result
# Create a report on the data-fetching
report = DataEntryReport(
entry=entry, hash=DataUpdater.calc_hash(policy_data), fetched=True
entry=entry,
hash=DataUpdater.calc_hash(policy_data),
fetched=True
)

try:
Expand Down
7 changes: 6 additions & 1 deletion packages/opal-client/opal_client/tests/data_updater_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,12 @@

from opal_client.config import opal_client_config
from opal_client.data.rpc import TenantAwareRpcEventClientMethods
from opal_client.data.updater import DataSourceEntry, DataUpdate, DataUpdater, DefaultDataUpdater
from opal_client.data.updater import (
DataSourceEntry,
DataUpdate,
DataUpdater,
DefaultDataUpdater,
)
from opal_client.policy_store.policy_store_client_factory import (
PolicyStoreClientFactory,
)
Expand Down
12 changes: 8 additions & 4 deletions packages/opal-common/opal_common/authentication/jwk.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
import jwt
import httpx

import jwt
from cachetools import TTLCache
from opal_common.authentication.verifier import Unauthorized

class JWKManager:
def __init__(self, openid_configuration_url, jwt_algorithm, cache_maxsize, cache_ttl):
def __init__(
self, openid_configuration_url, jwt_algorithm, cache_maxsize, cache_ttl
):
self._openid_configuration_url = openid_configuration_url
self._jwt_algorithm = jwt_algorithm
self._cache = TTLCache(maxsize=cache_maxsize, ttl=cache_ttl)

def public_key(self, token):
header = jwt.get_unverified_header(token)
kid = header['kid']
kid = header["kid"]

public_key = self._cache.get(kid)
if public_key is None:
Expand Down Expand Up @@ -40,6 +42,8 @@ def _openid_configuration(self):
response = httpx.get(self._openid_configuration_url)

if response.status_code != httpx.codes.OK:
raise Unauthorized(description=f"invalid status code {response.status_code}")
raise Unauthorized(
description=f"invalid status code {response.status_code}"
)

return response.json()
28 changes: 16 additions & 12 deletions packages/opal-common/opal_common/authentication/oauth2.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import asyncio
import httpx
import time
from typing import Optional

from cachetools import cached, TTLCache
import httpx
from cachetools import TTLCache, cached
from fastapi import Header
from httpx import AsyncClient, BasicAuth
from opal_common.authentication.authenticator import Authenticator
Expand All @@ -11,13 +12,12 @@
from opal_common.authentication.signer import JWTSigner
from opal_common.authentication.verifier import JWTVerifier, Unauthorized
from opal_common.config import opal_common_config
from typing import Optional

class _OAuth2Authenticator(Authenticator):
async def authenticate(self, headers):
if "Authorization" not in headers:
token = await self.token()
headers['Authorization'] = f"Bearer {token}"
headers["Authorization"] = f"Bearer {token}"


class OAuth2ClientCredentialsAuthenticator(_OAuth2Authenticator):
Expand Down Expand Up @@ -61,7 +61,7 @@ async def token(self):

async with AsyncClient() as client:
response = await client.post(self._token_url, auth=auth, data=data)
return (response.json())['access_token']
return (response.json())["access_token"]

def __call__(self, authorization: Optional[str] = Header(None)) -> {}:
token = get_token_from_header(authorization)
Expand All @@ -79,10 +79,12 @@ def verify(self, token: str) -> {}:
return claims

def _verify_opaque(self, token: str) -> {}:
response = httpx.post(self._introspect_url, data={'token': token})
response = httpx.post(self._introspect_url, data={"token": token})

if response.status_code != httpx.codes.OK:
raise Unauthorized(description=f"invalid status code {response.status_code}")
raise Unauthorized(
description=f"invalid status code {response.status_code}"
)

claims = response.json()
active = claims.get("active", False)
Expand Down Expand Up @@ -152,13 +154,15 @@ async def token(self):
claims = self._delegate.verify(token)

self._token = token
self._exp = claims['exp']
self._exp = claims["exp"]

return self._token

@cached(cache=TTLCache(
maxsize=opal_common_config.OAUTH2_TOKEN_VERIFY_CACHE_MAXSIZE,
ttl=opal_common_config.OAUTH2_TOKEN_VERIFY_CACHE_TTL
))
@cached(
cache=TTLCache(
maxsize=opal_common_config.OAUTH2_TOKEN_VERIFY_CACHE_MAXSIZE,
ttl=opal_common_config.OAUTH2_TOKEN_VERIFY_CACHE_TTL
)
)
def __call__(self, authorization: Optional[str] = Header(None)) -> {}:
return self._delegate(authorization)
2 changes: 1 addition & 1 deletion packages/opal-server/opal_server/data/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

from fastapi import APIRouter, Depends, Header, HTTPException, status
from fastapi.responses import RedirectResponse
from opal_common.authentication.authenticator import Authenticator
from opal_common.authentication.authz import (
require_peer_type,
restrict_optional_topics_to_publish,
)
from opal_common.authentication.authenticator import Authenticator
from opal_common.authentication.deps import get_token_from_header
from opal_common.authentication.types import JWTClaims
from opal_common.authentication.verifier import Unauthorized
Expand Down
1 change: 0 additions & 1 deletion packages/opal-server/opal_server/scopes/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
require_peer_type,
restrict_optional_topics_to_publish,
)
from opal_common.authentication.authenticator import Authenticator
from opal_common.authentication.casting import cast_private_key
from opal_common.authentication.deps import get_token_from_header
from opal_common.authentication.types import EncryptionKeyFormat, JWTClaims
Expand Down

0 comments on commit 89e2d99

Please sign in to comment.