Skip to content

Commit

Permalink
feat: support auto-retry of failed network requests to /oauth/token
Browse files Browse the repository at this point in the history
… endpoint (#79)
  • Loading branch information
evansims authored Apr 1, 2024
2 parents f337b94 + dea01cc commit 09bce8e
Show file tree
Hide file tree
Showing 9 changed files with 439 additions and 47 deletions.
1 change: 0 additions & 1 deletion openfga_sdk/api/open_fga_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
NOTE: This file was auto generated by OpenAPI Generator (https://openapi-generator.tech). DO NOT EDIT.
"""


from openfga_sdk.api_client import ApiClient
from openfga_sdk.exceptions import ApiValueError, FgaValidationException
from openfga_sdk.oauth2 import OAuth2Client
Expand Down
85 changes: 69 additions & 16 deletions openfga_sdk/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,47 @@
NOTE: This file was auto generated by OpenAPI Generator (https://openapi-generator.tech). DO NOT EDIT.
"""

import asyncio
import json
import math
import random
import sys
from datetime import datetime, timedelta

import urllib3

from openfga_sdk.configuration import Configuration
from openfga_sdk.credentials import Credentials
from openfga_sdk.exceptions import AuthenticationError


def jitter(loop_count, min_wait_in_ms):
"""
Generate a random jitter value for exponential backoff
"""
minimum = math.ceil(2**loop_count * min_wait_in_ms)
maximum = math.ceil(2 ** (loop_count + 1) * min_wait_in_ms)
jitter = random.randrange(minimum, maximum) / 1000

# If running in pytest, set jitter to 0 to speed up tests
if "pytest" in sys.modules:
jitter = 0

return jitter


class OAuth2Client:

def __init__(self, credentials: Credentials):
def __init__(self, credentials: Credentials, configuration=None):
self._credentials = credentials
self._access_token = None
self._access_expiry_time = None

if configuration is None:
configuration = Configuration.get_default_copy()

self.configuration = configuration

def _token_valid(self):
"""
Return whether token is valid
Expand All @@ -41,37 +66,65 @@ async def _obtain_token(self, client):
Perform OAuth2 and obtain token
"""
configuration = self._credentials.configuration

token_url = f"https://{configuration.api_issuer}/oauth/token"

post_params = {
"client_id": configuration.client_id,
"client_secret": configuration.client_secret,
"audience": configuration.api_audience,
"grant_type": "client_credentials",
}

headers = urllib3.response.HTTPHeaderDict(
{
"Accept": "application/json",
"Content-Type": "application/x-www-form-urlencoded",
"User-Agent": "openfga-sdk (python) 0.4.1",
}
)
raw_response = await client.POST(
token_url, headers=headers, post_params=post_params

max_retry = (
self.configuration.retry_params.max_retry
if (
self.configuration.retry_params is not None
and self.configuration.retry_params.max_retry is not None
)
else 0
)
if 200 <= raw_response.status <= 299:
try:
api_response = json.loads(raw_response.data)
except:
raise AuthenticationError(http_resp=raw_response)
if not api_response.get("expires_in") or not api_response.get(
"access_token"
):
raise AuthenticationError(http_resp=raw_response)
self._access_expiry_time = datetime.now() + timedelta(
seconds=int(api_response.get("expires_in"))

min_wait_in_ms = (
self.configuration.retry_params.min_wait_in_ms
if (
self.configuration.retry_params is not None
and self.configuration.retry_params.min_wait_in_ms is not None
)
else 0
)

for attempt in range(max_retry + 1):
raw_response = await client.POST(
token_url, headers=headers, post_params=post_params
)
self._access_token = api_response.get("access_token")
else:

if 500 <= raw_response.status <= 599 or raw_response.status == 429:
if attempt < max_retry and raw_response.status != 501:
await asyncio.sleep(jitter(attempt, min_wait_in_ms))
continue

if 200 <= raw_response.status <= 299:
try:
api_response = json.loads(raw_response.data)
except:
raise AuthenticationError(http_resp=raw_response)

if api_response.get("expires_in") and api_response.get("access_token"):
self._access_expiry_time = datetime.now() + timedelta(
seconds=int(api_response.get("expires_in"))
)
self._access_token = api_response.get("access_token")
break

raise AuthenticationError(http_resp=raw_response)

async def get_authentication_header(self, client):
Expand Down
85 changes: 70 additions & 15 deletions openfga_sdk/sync/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,46 @@
"""

import json
import math
import random
import sys
import time
from datetime import datetime, timedelta

import urllib3

from openfga_sdk.configuration import Configuration
from openfga_sdk.credentials import Credentials
from openfga_sdk.exceptions import AuthenticationError


def jitter(loop_count, min_wait_in_ms):
"""
Generate a random jitter value for exponential backoff
"""
minimum = math.ceil(2**loop_count * min_wait_in_ms)
maximum = math.ceil(2 ** (loop_count + 1) * min_wait_in_ms)
jitter = random.randrange(minimum, maximum) / 1000

# If running in pytest, set jitter to 0 to speed up tests
if "pytest" in sys.modules:
jitter = 0

return jitter


class OAuth2Client:

def __init__(self, credentials: Credentials):
def __init__(self, credentials: Credentials, configuration=None):
self._credentials = credentials
self._access_token = None
self._access_expiry_time = None

if configuration is None:
configuration = Configuration.get_default_copy()

self.configuration = configuration

def _token_valid(self):
"""
Return whether token is valid
Expand All @@ -41,35 +66,65 @@ def _obtain_token(self, client):
Perform OAuth2 and obtain token
"""
configuration = self._credentials.configuration

token_url = f"https://{configuration.api_issuer}/oauth/token"

post_params = {
"client_id": configuration.client_id,
"client_secret": configuration.client_secret,
"audience": configuration.api_audience,
"grant_type": "client_credentials",
}

headers = urllib3.response.HTTPHeaderDict(
{
"Accept": "application/json",
"Content-Type": "application/x-www-form-urlencoded",
"User-Agent": "openfga-sdk (python) 0.4.1",
}
)
raw_response = client.POST(token_url, headers=headers, post_params=post_params)
if 200 <= raw_response.status <= 299:
try:
api_response = json.loads(raw_response.data)
except:
raise AuthenticationError(http_resp=raw_response)
if not api_response.get("expires_in") or not api_response.get(
"access_token"
):
raise AuthenticationError(http_resp=raw_response)
self._access_expiry_time = datetime.now() + timedelta(
seconds=int(api_response.get("expires_in"))

max_retry = (
self.configuration.retry_params.max_retry
if (
self.configuration.retry_params is not None
and self.configuration.retry_params.max_retry is not None
)
else 0
)

min_wait_in_ms = (
self.configuration.retry_params.min_wait_in_ms
if (
self.configuration.retry_params is not None
and self.configuration.retry_params.min_wait_in_ms is not None
)
else 0
)

for attempt in range(max_retry + 1):
raw_response = client.POST(
token_url, headers=headers, post_params=post_params
)
self._access_token = api_response.get("access_token")
else:

if 500 <= raw_response.status <= 599 or raw_response.status == 429:
if attempt < max_retry and raw_response.status != 501:
time.sleep(jitter(attempt, min_wait_in_ms))
continue

if 200 <= raw_response.status <= 299:
try:
api_response = json.loads(raw_response.data)
except:
raise AuthenticationError(http_resp=raw_response)

if api_response.get("expires_in") and api_response.get("access_token"):
self._access_expiry_time = datetime.now() + timedelta(
seconds=int(api_response.get("expires_in"))
)
self._access_token = api_response.get("access_token")
break

raise AuthenticationError(http_resp=raw_response)

def get_authentication_header(self, client):
Expand Down
2 changes: 1 addition & 1 deletion test-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@

mock >= 5.1.0, < 6
flake8 >= 7.0.0, < 8
pytest-cov >= 4.1.0, < 5
pytest-cov >= 5, < 6
griffe >= 0.41.2, < 1
8 changes: 4 additions & 4 deletions test/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def test_configuration_client_credentials(self):
configuration=CredentialConfiguration(
client_id="myclientid",
client_secret="mysecret",
api_issuer="www.testme.com",
api_issuer="issuer.fga.example",
api_audience="myaudience",
),
)
Expand All @@ -121,7 +121,7 @@ def test_configuration_client_credentials_missing_client_id(self):
method="client_credentials",
configuration=CredentialConfiguration(
client_secret="mysecret",
api_issuer="www.testme.com",
api_issuer="issuer.fga.example",
api_audience="myaudience",
),
)
Expand All @@ -136,7 +136,7 @@ def test_configuration_client_credentials_missing_client_secret(self):
method="client_credentials",
configuration=CredentialConfiguration(
client_id="myclientid",
api_issuer="www.testme.com",
api_issuer="issuer.fga.example",
api_audience="myaudience",
),
)
Expand Down Expand Up @@ -167,7 +167,7 @@ def test_configuration_client_credentials_missing_api_audience(self):
configuration=CredentialConfiguration(
client_id="myclientid",
client_secret="mysecret",
api_issuer="www.testme.com",
api_issuer="issuer.fga.example",
),
)
with self.assertRaises(openfga_sdk.ApiValueError):
Expand Down
Loading

0 comments on commit 09bce8e

Please sign in to comment.