Skip to content

Commit b787d8d

Browse files
authored
Merge pull request #499 from aldbr/main_FEAT_jwks-rotation
feat: introduce jwks
2 parents 587205b + 3f7138b commit b787d8d

File tree

27 files changed

+740
-152
lines changed

27 files changed

+740
-152
lines changed

diracx-client/src/diracx/client/_generated/aio/operations/_operations.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
build_jobs_unassign_bulk_jobs_sandboxes_request,
5555
build_jobs_unassign_job_sandboxes_request,
5656
build_well_known_get_installation_metadata_request,
57+
build_well_known_get_jwks_request,
5758
build_well_known_get_openid_configuration_request,
5859
)
5960
from .._configuration import DiracConfiguration
@@ -146,6 +147,57 @@ async def get_openid_configuration(
146147

147148
return deserialized # type: ignore
148149

150+
@distributed_trace_async
151+
async def get_jwks(self, **kwargs: Any) -> Dict[str, Any]:
152+
"""Get Jwks.
153+
154+
Get the JWKs (public keys).
155+
156+
:return: dict mapping str to any
157+
:rtype: dict[str, any]
158+
:raises ~azure.core.exceptions.HttpResponseError:
159+
"""
160+
error_map: MutableMapping = {
161+
401: ClientAuthenticationError,
162+
404: ResourceNotFoundError,
163+
409: ResourceExistsError,
164+
304: ResourceNotModifiedError,
165+
}
166+
error_map.update(kwargs.pop("error_map", {}) or {})
167+
168+
_headers = kwargs.pop("headers", {}) or {}
169+
_params = kwargs.pop("params", {}) or {}
170+
171+
cls: ClsType[Dict[str, Any]] = kwargs.pop("cls", None)
172+
173+
_request = build_well_known_get_jwks_request(
174+
headers=_headers,
175+
params=_params,
176+
)
177+
_request.url = self._client.format_url(_request.url)
178+
179+
_stream = False
180+
pipeline_response: PipelineResponse = (
181+
await self._client._pipeline.run( # pylint: disable=protected-access
182+
_request, stream=_stream, **kwargs
183+
)
184+
)
185+
186+
response = pipeline_response.http_response
187+
188+
if response.status_code not in [200]:
189+
map_error(
190+
status_code=response.status_code, response=response, error_map=error_map
191+
)
192+
raise HttpResponseError(response=response)
193+
194+
deserialized = self._deserialize("{object}", pipeline_response.http_response)
195+
196+
if cls:
197+
return cls(pipeline_response, deserialized, {}) # type: ignore
198+
199+
return deserialized # type: ignore
200+
149201
@distributed_trace_async
150202
async def get_installation_metadata(self, **kwargs: Any) -> _models.Metadata:
151203
"""Get Installation Metadata.

diracx-client/src/diracx/client/_generated/models/_models.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -563,6 +563,8 @@ class OpenIDConfiguration(_serialization.Model):
563563
:vartype device_authorization_endpoint: str
564564
:ivar revocation_endpoint: Revocation Endpoint. Required.
565565
:vartype revocation_endpoint: str
566+
:ivar jwks_uri: Jwks Uri. Required.
567+
:vartype jwks_uri: str
566568
:ivar grant_types_supported: Grant Types Supported. Required.
567569
:vartype grant_types_supported: list[str]
568570
:ivar scopes_supported: Scopes Supported. Required.
@@ -585,6 +587,7 @@ class OpenIDConfiguration(_serialization.Model):
585587
"authorization_endpoint": {"required": True},
586588
"device_authorization_endpoint": {"required": True},
587589
"revocation_endpoint": {"required": True},
590+
"jwks_uri": {"required": True},
588591
"grant_types_supported": {"required": True},
589592
"scopes_supported": {"required": True},
590593
"response_types_supported": {"required": True},
@@ -603,6 +606,7 @@ class OpenIDConfiguration(_serialization.Model):
603606
"type": "str",
604607
},
605608
"revocation_endpoint": {"key": "revocation_endpoint", "type": "str"},
609+
"jwks_uri": {"key": "jwks_uri", "type": "str"},
606610
"grant_types_supported": {"key": "grant_types_supported", "type": "[str]"},
607611
"scopes_supported": {"key": "scopes_supported", "type": "[str]"},
608612
"response_types_supported": {
@@ -632,6 +636,7 @@ def __init__(
632636
authorization_endpoint: str,
633637
device_authorization_endpoint: str,
634638
revocation_endpoint: str,
639+
jwks_uri: str,
635640
grant_types_supported: List[str],
636641
scopes_supported: List[str],
637642
response_types_supported: List[str],
@@ -653,6 +658,8 @@ def __init__(
653658
:paramtype device_authorization_endpoint: str
654659
:keyword revocation_endpoint: Revocation Endpoint. Required.
655660
:paramtype revocation_endpoint: str
661+
:keyword jwks_uri: Jwks Uri. Required.
662+
:paramtype jwks_uri: str
656663
:keyword grant_types_supported: Grant Types Supported. Required.
657664
:paramtype grant_types_supported: list[str]
658665
:keyword scopes_supported: Scopes Supported. Required.
@@ -675,6 +682,7 @@ def __init__(
675682
self.authorization_endpoint = authorization_endpoint
676683
self.device_authorization_endpoint = device_authorization_endpoint
677684
self.revocation_endpoint = revocation_endpoint
685+
self.jwks_uri = jwks_uri
678686
self.grant_types_supported = grant_types_supported
679687
self.scopes_supported = scopes_supported
680688
self.response_types_supported = response_types_supported

diracx-client/src/diracx/client/_generated/operations/_operations.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,20 @@ def build_well_known_get_openid_configuration_request(
5353
return HttpRequest(method="GET", url=_url, headers=_headers, **kwargs)
5454

5555

56+
def build_well_known_get_jwks_request(**kwargs: Any) -> HttpRequest:
57+
_headers = case_insensitive_dict(kwargs.pop("headers", {}) or {})
58+
59+
accept = _headers.pop("Accept", "application/json")
60+
61+
# Construct URL
62+
_url = "/.well-known/jwks.json"
63+
64+
# Construct headers
65+
_headers["Accept"] = _SERIALIZER.header("accept", accept, "str")
66+
67+
return HttpRequest(method="GET", url=_url, headers=_headers, **kwargs)
68+
69+
5670
def build_well_known_get_installation_metadata_request(
5771
**kwargs: Any,
5872
) -> HttpRequest: # pylint: disable=name-too-long
@@ -746,6 +760,57 @@ def get_openid_configuration(self, **kwargs: Any) -> _models.OpenIDConfiguration
746760

747761
return deserialized # type: ignore
748762

763+
@distributed_trace
764+
def get_jwks(self, **kwargs: Any) -> Dict[str, Any]:
765+
"""Get Jwks.
766+
767+
Get the JWKs (public keys).
768+
769+
:return: dict mapping str to any
770+
:rtype: dict[str, any]
771+
:raises ~azure.core.exceptions.HttpResponseError:
772+
"""
773+
error_map: MutableMapping = {
774+
401: ClientAuthenticationError,
775+
404: ResourceNotFoundError,
776+
409: ResourceExistsError,
777+
304: ResourceNotModifiedError,
778+
}
779+
error_map.update(kwargs.pop("error_map", {}) or {})
780+
781+
_headers = kwargs.pop("headers", {}) or {}
782+
_params = kwargs.pop("params", {}) or {}
783+
784+
cls: ClsType[Dict[str, Any]] = kwargs.pop("cls", None)
785+
786+
_request = build_well_known_get_jwks_request(
787+
headers=_headers,
788+
params=_params,
789+
)
790+
_request.url = self._client.format_url(_request.url)
791+
792+
_stream = False
793+
pipeline_response: PipelineResponse = (
794+
self._client._pipeline.run( # pylint: disable=protected-access
795+
_request, stream=_stream, **kwargs
796+
)
797+
)
798+
799+
response = pipeline_response.http_response
800+
801+
if response.status_code not in [200]:
802+
map_error(
803+
status_code=response.status_code, response=response, error_map=error_map
804+
)
805+
raise HttpResponseError(response=response)
806+
807+
deserialized = self._deserialize("{object}", pipeline_response.http_response)
808+
809+
if cls:
810+
return cls(pipeline_response, deserialized, {}) # type: ignore
811+
812+
return deserialized # type: ignore
813+
749814
@distributed_trace
750815
def get_installation_metadata(self, **kwargs: Any) -> _models.Metadata:
751816
"""Get Installation Metadata.

diracx-core/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@ classifiers = [
1414
]
1515
dependencies = [
1616
"aiobotocore>=2.15",
17-
"authlib",
1817
"botocore>=1.35",
1918
"cachetools",
2019
"email_validator",
2120
"gitpython",
21+
"joserfc",
2222
"pydantic >=2.10",
2323
"pydantic-settings",
2424
"pyyaml",

diracx-core/src/diracx/core/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ class OpenIDConfiguration(TypedDict):
202202
authorization_endpoint: str
203203
device_authorization_endpoint: str
204204
revocation_endpoint: str
205+
jwks_uri: str
205206
grant_types_supported: list[str]
206207
scopes_supported: list[str]
207208
response_types_supported: list[str]

diracx-core/src/diracx/core/settings.py

Lines changed: 55 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from __future__ import annotations
44

5+
import json
6+
57
from diracx.core.properties import SecurityProperty
68
from diracx.core.s3 import s3_bucket_exists
79

@@ -17,11 +19,12 @@
1719
from typing import TYPE_CHECKING, Annotated, Any, Self, TypeVar
1820

1921
from aiobotocore.session import get_session
20-
from authlib.jose import JsonWebKey
2122
from botocore.config import Config
2223
from botocore.errorfactory import ClientError
2324
from cryptography.fernet import Fernet
25+
from joserfc.jwk import KeySet, RSAKey
2426
from pydantic import (
27+
AliasChoices,
2528
AnyUrl,
2629
BeforeValidator,
2730
Field,
@@ -51,28 +54,52 @@ class SqlalchemyDsn(AnyUrl):
5154
)
5255

5356

54-
class _TokenSigningKey(SecretStr):
55-
jwk: JsonWebKey
57+
class _TokenSigningKeyStore(SecretStr):
58+
jwks: KeySet
5659

5760
def __init__(self, data: str):
5861
super().__init__(data)
59-
self.jwk = JsonWebKey.import_key(self.get_secret_value())
60-
61-
62-
def _maybe_load_key_from_file(value: Any) -> Any:
63-
"""Load private keys from files if needed."""
64-
if isinstance(value, str) and not value.strip().startswith("-----BEGIN"):
65-
url = TypeAdapter(LocalFileUrl).validate_python(value)
66-
if not url.scheme == "file":
67-
raise ValueError("Only file:// URLs are supported")
68-
if url.path is None:
69-
raise ValueError("No path specified")
70-
value = Path(url.path).read_text()
62+
63+
# Load the keys from the JSON string
64+
try:
65+
keys = json.loads(self.get_secret_value())
66+
except json.JSONDecodeError as e:
67+
raise ValueError("Invalid JSON string") from e
68+
if not isinstance(keys, dict):
69+
raise ValueError("Invalid JSON string")
70+
self.jwks = KeySet.import_key_set(keys) # type: ignore
71+
72+
73+
def _maybe_load_keys_from_file(value: Any) -> Any:
74+
"""Load jwks from files if needed."""
75+
if isinstance(value, str):
76+
# If the value is a string, we need to check if it is a JSON string or a file URL
77+
if not (value.strip().startswith("{") or value.startswith("[")):
78+
# If it is not a JSON string, we assume it is a file URL
79+
url = TypeAdapter(LocalFileUrl).validate_python(value)
80+
if not url.scheme == "file":
81+
raise ValueError("Only file:// URLs are supported")
82+
if url.path is None:
83+
raise ValueError("No path specified")
84+
value = Path(url.path).read_text()
85+
86+
if isinstance(value, str) and value.strip().startswith("-----BEGIN"):
87+
return json.dumps(
88+
KeySet(
89+
keys=[
90+
RSAKey.import_key(
91+
value, # type: ignore
92+
parameters={"key_ops": ["sign", "verify"], "alg": "RS256"}, # type: ignore
93+
)
94+
]
95+
).as_dict(private=True)
96+
)
7197
return value
7298

7399

74-
TokenSigningKey = Annotated[
75-
_TokenSigningKey, BeforeValidator(_maybe_load_key_from_file)
100+
TokenSigningKeyStore = Annotated[
101+
_TokenSigningKeyStore,
102+
BeforeValidator(_maybe_load_keys_from_file),
76103
]
77104

78105

@@ -124,7 +151,9 @@ def create(cls) -> Self:
124151
class AuthSettings(ServiceSettingsBase):
125152
"""Settings for the authentication service."""
126153

127-
model_config = SettingsConfigDict(env_prefix="DIRACX_SERVICE_AUTH_")
154+
model_config = SettingsConfigDict(
155+
env_prefix="DIRACX_SERVICE_AUTH_", validate_by_name=True
156+
)
128157

129158
dirac_client_id: str = "myDIRACClientID"
130159
# TODO: This should be taken dynamically
@@ -137,8 +166,14 @@ class AuthSettings(ServiceSettingsBase):
137166
state_key: FernetKey
138167

139168
token_issuer: str
140-
token_key: TokenSigningKey
141-
token_algorithm: str = "RS256" # noqa: S105
169+
token_keystore: TokenSigningKeyStore = Field(
170+
validation_alias=AliasChoices(
171+
"token_keystore",
172+
"DIRACX_SERVICE_AUTH_TOKEN_KEYSTORE",
173+
"DIRACX_SERVICE_AUTH_TOKEN_KEY",
174+
)
175+
)
176+
token_allowed_algorithms: list[str] = ["RS256", "EdDSA"] # noqa: S105
142177
access_token_expire_minutes: int = 20
143178
refresh_token_expire_minutes: int = 60
144179

0 commit comments

Comments
 (0)