Skip to content

Commit

Permalink
Refactor jwks fetch (#65)
Browse files Browse the repository at this point in the history
* #2

* #1

* #3

* fix http headers

* optional zid

* version++

* more strict replacing

* Align version with other libs

---------

Co-authored-by: Allen Liu <[email protected]>
Co-authored-by: robertofalk <[email protected]>
  • Loading branch information
3 people authored Nov 23, 2023
1 parent d90c9e0 commit 0c55482
Show file tree
Hide file tree
Showing 10 changed files with 143 additions and 48 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@

# Change Log
All notable changes to this project will be documented in this file.

This project adheres to [Semantic Versioning](http://semver.org/).

The format is based on [Keep a Changelog](http://keepachangelog.com/).

## 4.1.0
### Changed
- Removed JKU validation for XSUAA tokens and replaced it with composing JKU using UAA Domain.
- Added extra HTTP headers for improved IAS verification key retrieval.
- Implemented more strict issuer validation for IAS tokens.

## 4.0.1
### Fixed
- Bug: fix `aud` validation for IAS tokens
Expand Down
19 changes: 13 additions & 6 deletions sap/xssec/key_cache_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from collections import defaultdict
from functools import _make_key # noqa
from threading import Lock
from typing import Optional, List, Dict, Any
from typing import List, Dict, Any

import httpx
from cachetools import cached, TTLCache
Expand Down Expand Up @@ -42,12 +42,18 @@ def _fetch_verification_key_url_ias(issuer_url: str) -> str:
return resp.json()["jwks_uri"]


def _download_verification_key_ias(verification_key_url: str, zone_id: Optional[str]) -> List[Dict[str, Any]]:
def _download_verification_key_ias(verification_key_url: str, app_tid: str, azp: str,
client_id: str) -> List[Dict[str, Any]]:
"""
get all the keys from verification key url
"""
default_headers = {'Accept': 'application/json'}
headers = default_headers if zone_id is None else {**default_headers, "x-zone_uuid": zone_id}
headers = {
'x-app_tid': app_tid,
'x-azp': azp,
'x-client_id': client_id,
'Accept': 'application/json',
}
headers = {k: v for k, v in headers.items() if v is not None}
resp = httpx.get(verification_key_url, headers=headers, timeout=HTTP_TIMEOUT_IN_SECONDS)
resp.raise_for_status()
return resp.json()["keys"]
Expand All @@ -59,12 +65,13 @@ def _download_verification_key_ias(verification_key_url: str, zone_id: Optional[

@thread_safe_by_args
@cached(cache=key_cache)
def get_verification_key_ias(issuer_url: str, zone_id: Optional[str], kid: str) -> str:
def get_verification_key_ias(issuer_url: str, app_tid: str, azp: str, client_id: str, kid: str) -> str:
"""
get verification key for ias
"""
verification_key_url: str = _fetch_verification_key_url_ias(issuer_url)
verification_key_list: List[Dict[str, Any]] = _download_verification_key_ias(verification_key_url, zone_id)
verification_key_list: List[Dict[str, Any]] = _download_verification_key_ias(verification_key_url, app_tid, azp,
client_id)
found = list(filter(lambda k: k["kid"] == kid, verification_key_list))
if len(found) == 0:
raise ValueError("Could not find key with kid {}".format(kid))
Expand Down
15 changes: 13 additions & 2 deletions sap/xssec/security_context_ias.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
""" Security Context class for IAS support"""
import logging
import re
from typing import List, Dict
from urllib3.util import Url, parse_url # type: ignore
from sap.xssec.jwt_audience_validator import JwtAudienceValidator
Expand Down Expand Up @@ -45,7 +46,12 @@ def validate_issuer(self):

domains: List[str] = self.service_credentials.get("domains") or (
[self.service_credentials["domain"]] if "domain" in self.service_credentials else [])
if not any(map(lambda d: issuer_url.host.endswith(d), domains)):

def validate_issuer_subdomain(parent_domain) -> bool:
pattern = r'^https://[a-zA-Z0-9-]{{1,63}}\.{parent_domain}$'.format(parent_domain=re.escape(parent_domain))
return bool(re.match(pattern, self.get_issuer()))

if not any(map(validate_issuer_subdomain, domains)):
raise ValueError("Token's issuer is not found in domain list {}".format(", ".join(domains)))

return self
Expand Down Expand Up @@ -76,7 +82,12 @@ def validate_signature(self):
check signature in jwt token
"""
verification_key: str = get_verification_key_ias(
self.get_issuer(), self.token_payload.get("zone_uuid"), self.token_header["kid"])
issuer_url=self.get_issuer(),
app_tid=self.token_payload.get("app_tid") or self.token_payload.get("zone_uuid"),
azp=self.token_payload.get("azp"),
client_id=self.service_credentials["clientid"],
kid=self.token_header["kid"],
)

result_code = self.jwt_validator.loadPEM(verification_key)
if result_code != 0:
Expand Down
27 changes: 11 additions & 16 deletions sap/xssec/security_context_xsuaa.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# pylint: disable=too-many-public-methods
""" Security Context class """
import functools
import re
import tempfile
from os import environ, unlink
import json
Expand All @@ -10,7 +11,6 @@

import httpx
import deprecation
import urllib3

from sap.xssec import constants
from sap.xssec.jwt_validation_facade import JwtValidationFacade, DecodeError
Expand Down Expand Up @@ -118,17 +118,14 @@ def _init_xsappname(self):
' the manifest.yml (legacy) as well as in xs-security.json.'
' Remove it in manifest.yml.')

def _validate_jku(self):
# configured uaa domain must be part of jku in order to trust jku
uaa_domain = self._config['uaadomain']

def _get_jku(self):
uaa_domain: str = self._config.get('uaadomain') or self._config.get('url')
if not uaa_domain:
raise RuntimeError("Service is not properly configured in 'VCAP_SERVICES'")

jku_url = urllib3.util.parse_url(self._properties['jku'])
if not jku_url.hostname.endswith(uaa_domain):
self._logger.error("Error: Do not trust jku '{}' because it does not match uaa domain".format(self._properties['jku']))
raise RuntimeError("JKU of token is not trusted")
uaa_domain = re.sub(r'^https://', '', uaa_domain)
payload = self._jwt_validator.decode(self._token, verify=False)
zid = payload.get("zid")
return f"https://{uaa_domain}/token_keys?zid={zid}" if zid else f"https://{uaa_domain}/token_keys"

def _set_token_properties(self):

Expand Down Expand Up @@ -338,13 +335,11 @@ def _set_scopes(self, jwt_payload):
self._logger.debug('Obtained scopes: %s.', self._properties['scopes'])

def _validate_token(self):
""" Try to retrieve the key from the uaa if jku and kid is set. Otherwise use configured one."""

if "uaadomain" in self._config and self._properties['jku'] and self._properties['kid']:
self._validate_jku()
""" Try to retrieve the key from the composed jku if kid is set. Otherwise use configured one."""
if self._properties['kid']:
try:
verification_key = SecurityContextXSUAA.verificationKeyCache.load_key(self._properties['jku'],
self._properties['kid'])
jku = self._get_jku()
verification_key = SecurityContextXSUAA.verificationKeyCache.load_key(jku, self._properties['kid'])
return self._get_jwt_payload(verification_key)
except (DecodeError, RuntimeError, IOError) as e:
self._logger.warning("Warning: Could not validate key: {} Will retry with configured key.".format(e))
Expand Down
14 changes: 13 additions & 1 deletion tests/ias/ias_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,17 @@
"-1opcRmlX1x8zgi2l-XxzSKrLABz0Fq2GJGZmD1PU_"
"-W6FHzE7ocokfYSViJ1_mBGn5KJwUIC2vBO9jWquGlM9TkdPP5DpmONEO5yFu6aO6GeEF3k9hOEL0AS0GOm8KmywhDg"
"-s5FGVNuwNG0O_nQn3VI9jigXuKuz5_e1becT2rw88fpizFG476TwB6BQCk8SWc "
}, {
"kty": "RSA",
"e": "AQAB",
"use": "sig",
"kid": "another-kid",
"alg": "RS256",
"value": "public key here",
"n": "AMGmSCHT8kourWCKVwTQKKr7a_rs8AGiwVPmeycKq2Mja5P3YXMDMOO7Qb9"
"-v5YNv0dkD7eu9v4AzilpcnnGASQbewNbaz2wJWMwIvjxG7VcHjqcf-oF9bfHv8nR1TTp52OwSKaKqunMtIrS1uJ"
"-1opcRmlX1x8zgi2l-XxzSKrLABz0Fq2GJGZmD1PU_"
"-W6FHzE7ocokfYSViJ1_mBGn5KJwUIC2vBO9jWquGlM9TkdPP5DpmONEO5yFu6aO6GeEF3k9hOEL0AS0GOm8KmywhDg"
"-s5FGVNuwNG0O_nQn3VI9jigXuKuz5_e1becT2rw88fpizFG476TwB6BQCk8SWc "
}]
}
}
1 change: 1 addition & 0 deletions tests/ias/ias_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def merge(dict1, dict2):
"user_uuid": "db60e49c-1fb7-4a15-9a9e-8ababf856fe9",
"azp": "70af88d4-0371-4374-b4f5-f24f650bfac5",
"zone_uuid": "4b0c2b7a-1279-4352-a68d-a9a228a4f1e9",
"app_tid": "4b0c2b7a-1279-4352-a68d-a9a228a4f1e9",
"iat": 1470815434,
"exp": 2101535434,
"family_name": "Nachname",
Expand Down
10 changes: 9 additions & 1 deletion tests/ias/test_xssec_ias.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@
reload = None
from mock import MagicMock, patch

VERIFICATION_KEY_PARAMS = {
"issuer_url": PAYLOAD["iss"],
"app_tid": PAYLOAD["app_tid"] or PAYLOAD["zone_uuid"],
"azp": PAYLOAD["azp"],
"client_id": SERVICE_CREDENTIALS["clientid"],
"kid": HEADER["kid"]
}


class IASXSSECTest(unittest.TestCase):

Expand All @@ -27,7 +35,7 @@ def setUp(self):
@patch('sap.xssec.security_context_ias.get_verification_key_ias', return_value=JWT_SIGNING_PUBLIC_KEY)
def test_input_validation_valid_token(self, get_verification_key_ias_mock):
xssec.create_security_context_ias(VALID_TOKEN, ias_configs.SERVICE_CREDENTIALS)
get_verification_key_ias_mock.assert_called_with(PAYLOAD["iss"], PAYLOAD["zone_uuid"], HEADER["kid"])
get_verification_key_ias_mock.assert_called_with(**VERIFICATION_KEY_PARAMS)

def test_input_validation_invalid_token(self):
with self.assertRaises(ValueError) as ctx:
Expand Down
61 changes: 50 additions & 11 deletions tests/test_key_cache_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,19 @@
from typing import List, Callable

import pytest
from httpx import Response, HTTPStatusError
from httpx import Response, HTTPStatusError, Request

from sap.xssec.key_tools import jwk_to_pem
from tests.ias.ias_configs import JWKS, WELL_KNOWN
from tests.ias.ias_tokens import PAYLOAD, HEADER
from tests.ias.ias_configs import JWKS, WELL_KNOWN, SERVICE_CREDENTIALS
from tests.ias.ias_tokens import PAYLOAD, HEADER, merge

VERIFICATION_KEY_PARAMS = {
"issuer_url": PAYLOAD["iss"],
"app_tid": PAYLOAD["app_tid"] or PAYLOAD["zone_uuid"],
"azp": PAYLOAD["azp"],
"client_id": SERVICE_CREDENTIALS["clientid"],
"kid": HEADER["kid"]
}


def test_thread_safe_decorator():
Expand All @@ -30,13 +38,13 @@ def run_func_in_threads(func: Callable[[int], None], func_args: List[int]):

# not thread-safe without decorator
sum = 0
run_func_in_threads(add_to_sum, [1]*10)
run_func_in_threads(add_to_sum, [1] * 10)
assert 10 != sum

# thread-safe when args are same
sum = 0
from sap.xssec.key_cache_v2 import thread_safe_by_args
run_func_in_threads(thread_safe_by_args(add_to_sum), [1]*10)
run_func_in_threads(thread_safe_by_args(add_to_sum), [1] * 10)
assert 10 == sum

# not thread-safe when args are different
Expand All @@ -51,15 +59,22 @@ def well_known_endpoint_mock(respx_mock):
return_value=Response(200, json=WELL_KNOWN))


def jwk_endpoint_response(request: Request):
if all(k in request.headers for k in ("x-app_tid", "x-azp", "x-client_id")):
return Response(200, json=JWKS)
else:
return Response(404)


@pytest.fixture
def jwk_endpoint_mock(respx_mock):
return respx_mock.get(WELL_KNOWN["jwks_uri"]).mock(return_value=Response(200, json=JWKS))
return respx_mock.get(WELL_KNOWN["jwks_uri"]).mock(side_effect=jwk_endpoint_response)


def test_get_verification_key_ias_should_return_key(well_known_endpoint_mock, jwk_endpoint_mock):
from sap.xssec.key_cache_v2 import get_verification_key_ias, key_cache
key_cache.clear()
pem_key = get_verification_key_ias(PAYLOAD["iss"], PAYLOAD["zone_uuid"], HEADER["kid"])
pem_key = get_verification_key_ias(**VERIFICATION_KEY_PARAMS)
assert well_known_endpoint_mock.called
assert jwk_endpoint_mock.called
jwk = next(filter(lambda k: k["kid"] == HEADER["kid"], JWKS["keys"]))
Expand All @@ -70,9 +85,33 @@ def test_get_verification_key_ias_should_cache_key(well_known_endpoint_mock, jwk
from sap.xssec.key_cache_v2 import get_verification_key_ias, key_cache
key_cache.clear()
for _ in range(0, 10):
get_verification_key_ias(PAYLOAD["iss"], PAYLOAD["zone_uuid"], HEADER["kid"])
assert 1 == well_known_endpoint_mock.call_count
assert 1 == jwk_endpoint_mock.call_count
get_verification_key_ias(**VERIFICATION_KEY_PARAMS)
assert 1 == well_known_endpoint_mock.call_count == jwk_endpoint_mock.call_count

for _ in range(0, 10):
get_verification_key_ias(**merge(VERIFICATION_KEY_PARAMS, {"app_tid": "another-app-tid"}))
assert 2 == well_known_endpoint_mock.call_count == jwk_endpoint_mock.call_count

for _ in range(0, 10):
get_verification_key_ias(**merge(VERIFICATION_KEY_PARAMS, {"azp": "another-azp"}))
assert 3 == well_known_endpoint_mock.call_count == jwk_endpoint_mock.call_count

for _ in range(0, 10):
get_verification_key_ias(**merge(VERIFICATION_KEY_PARAMS, {"client_id": "another-client-id"}))
assert 4 == well_known_endpoint_mock.call_count == jwk_endpoint_mock.call_count

for _ in range(0, 10):
get_verification_key_ias(**merge(VERIFICATION_KEY_PARAMS, {"kid": "another-kid"}))
assert 5 == well_known_endpoint_mock.call_count == jwk_endpoint_mock.call_count


def test_get_verification_key_ias_should_throw_error_for_missing_key(well_known_endpoint_mock, jwk_endpoint_mock):
from sap.xssec.key_cache_v2 import get_verification_key_ias, key_cache
key_cache.clear()
for _ in range(0, 10):
with pytest.raises(ValueError):
get_verification_key_ias(**merge(VERIFICATION_KEY_PARAMS, {"kid": "non-existing-kid"}))
assert 10 == well_known_endpoint_mock.call_count == jwk_endpoint_mock.call_count


def test_get_verification_key_ias_should_raise_http_error(respx_mock):
Expand All @@ -81,4 +120,4 @@ def test_get_verification_key_ias_should_raise_http_error(respx_mock):
from sap.xssec.key_cache_v2 import get_verification_key_ias, key_cache
key_cache.clear()
with pytest.raises(HTTPStatusError):
get_verification_key_ias(PAYLOAD["iss"], PAYLOAD["zone_uuid"], HEADER["kid"])
get_verification_key_ias(**VERIFICATION_KEY_PARAMS)
35 changes: 25 additions & 10 deletions tests/test_xssec.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,14 +458,22 @@ def test_get_verification_key_from_uaa(self, mock_requests):
sec_context.get_clone_service_instance_id(), 'abcd1234')
self.assertEqual(
sec_context.get_additional_auth_attribute('external_group'), 'domaingroup1')
mock_requests.assert_called_once_with("https://api.cf.test.com", timeout=constants.HTTP_TIMEOUT_IN_SECONDS)
mock_requests.assert_called_once_with("https://api.cf.test.com/token_keys?zid=test-idz",
timeout=constants.HTTP_TIMEOUT_IN_SECONDS)

def test_not_trusted_jku(self):
@patch('httpx.get')
def test_composed_jku_with_uaadomain(self, mock_requests):
from sap.xssec.key_cache import KeyCache
xssec.SecurityContextXSUAA.verificationKeyCache = KeyCache()

with self.assertRaises(RuntimeError) as e:
xssec.create_security_context(sign(jwt_payloads.USER_TOKEN), uaa_configs.VALID['uaa_no_verification_key_other_domain'])
mock = MagicMock()
mock_requests.return_value = mock
mock.json.return_value = HTTP_SUCCESS

self.assertEqual("JKU of token is not trusted", str(e.exception),)
xssec.create_security_context(
sign(jwt_payloads.USER_TOKEN), uaa_configs.VALID['uaa_no_verification_key_other_domain'])
mock_requests.assert_called_once_with("https://api.cf2.test.com/token_keys?zid=test-idz",
timeout=constants.HTTP_TIMEOUT_IN_SECONDS)

def test_valid_xsa_token_attributes(self):
''' valid client credentials token (with attributes) '''
Expand All @@ -475,7 +483,6 @@ def test_valid_xsa_token_attributes(self):
self.assertEqual(
sec_context.get_logon_name(), 'ADMIN')


def test_valid_xsa_token_with_newlines(self):
''' valid client credentials token (with attributes) '''
sec_context = xssec.create_security_context(
Expand All @@ -484,12 +491,20 @@ def test_valid_xsa_token_with_newlines(self):
self.assertEqual(
sec_context.get_logon_name(), 'ADMIN')

def test_invalid_jku_in_token_header(self):
@patch('httpx.get')
def test_ignored_invalid_jku_in_token_header(self, mock_requests):
from sap.xssec.key_cache import KeyCache
xssec.SecurityContextXSUAA.verificationKeyCache = KeyCache()

uaa_config = uaa_configs.VALID['uaa']
token = sign(jwt_payloads.USER_TOKEN, headers={
"jku": 'http://ana.ondemandh.com\\\\\\\\\\\\\\\\@' + uaa_config['uaadomain'],
"kid": "key-id-0"
})
with self.assertRaises(RuntimeError) as e:
xssec.create_security_context(token, uaa_config)
self.assertEqual("JKU of token is not trusted", str(e.exception),)
mock = MagicMock()
mock_requests.return_value = mock
mock.json.return_value = HTTP_SUCCESS

xssec.create_security_context(token, uaa_config)
mock_requests.assert_called_once_with("https://api.cf.test.com/token_keys?zid=test-idz",
timeout=constants.HTTP_TIMEOUT_IN_SECONDS)
2 changes: 1 addition & 1 deletion version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
4.0.1
4.1.0

0 comments on commit 0c55482

Please sign in to comment.