diff --git a/lemur/auth/views.py b/lemur/auth/views.py index 27c9ca6a9b..64c4266cbd 100644 --- a/lemur/auth/views.py +++ b/lemur/auth/views.py @@ -129,15 +129,16 @@ def validate_id_token(id_token, client_id, jwks_url): # validate your token based on the key it was signed with try: - jwt.decode( + id_token_decoded = jwt.decode( id_token, secret.decode("utf-8"), algorithms=[algo], audience=client_id ) + return id_token_decoded, None except jwt.DecodeError: - return dict(message="Token is invalid"), 401 + return None, dict(message="Token is invalid"), 401 except jwt.ExpiredSignatureError: - return dict(message="Token has expired"), 401 + return None, dict(message="Token has expired"), 401 except jwt.InvalidTokenError: - return dict(message="Token is invalid"), 401 + return None, dict(message="Token is invalid"), 401 def retrieve_user(user_api_url, access_token): @@ -152,7 +153,9 @@ def retrieve_user(user_api_url, access_token): headers = {} - if current_app.config.get("PING_INCLUDE_BEARER_TOKEN"): + if current_app.config.get("PING_INCLUDE_BEARER_TOKEN") and "ping" in current_app.config.get("ACTIVE_PROVIDERS"): + headers = {"Authorization": f"Bearer {access_token}"} + else: headers = {"Authorization": f"Bearer {access_token}"} # retrieve information about the current user. @@ -221,7 +224,7 @@ def create_user_roles(profile: dict) -> list[str]: # If the IDP_GROUPS_TO_ROLES is empty or not set, nothing happens. idp_group_to_role_map = current_app.config.get("IDP_ROLES_MAPPING", {}) matched_roles = [ - idp_group_to_role_map[role] for role in profile.get(idp_groups_key, []) if role in idp_group_to_role_map + role_service.get_by_name(idp_group_to_role_map[role]) for role in profile.get(idp_groups_key, []) if role in idp_group_to_role_map ] roles.extend(matched_roles) @@ -536,7 +539,7 @@ def post(self): ) jwks_url = current_app.config.get("PING_JWKS_URL") - error_code = validate_id_token(id_token, args["clientId"], jwks_url) + id_token_decoded, error_code = validate_id_token(id_token, args["clientId"], jwks_url) if error_code: return error_code @@ -602,12 +605,12 @@ def post(self): ) jwks_url = current_app.config.get("OAUTH2_JWKS_URL") - error_code = validate_id_token(id_token, args["clientId"], jwks_url) + id_token_decoded, error_code = validate_id_token(id_token, args["clientId"], jwks_url) if error_code: return error_code user, profile = retrieve_user(user_api_url, access_token) - roles = create_user_roles(profile) + roles = create_user_roles(id_token_decoded) user = update_user(user, profile, roles) if not user.active: