Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More flexible config and event polling in receiver #59

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 37 additions & 13 deletions examples/receiver/receiver/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,18 @@
# that can be found in the LICENSE file.

import asyncio
from asyncio import events
import socket
import sys
import logging
import time
import requests
import threading
import os
import json
import threading
from logging import Logger
from logging.config import dictConfig

import urllib
from pathlib import Path
from typing import Any
import requests
from flask import Flask, request
from http.server import HTTPServer, BaseHTTPRequestHandler
from logging.config import dictConfig

from .client import TransmitterClient


Expand All @@ -34,6 +32,18 @@ async def wait_until_available(host, port):
await asyncio.sleep(1)


def poll_events_continuously(client: TransmitterClient, poll_url: str, logger: Logger):
while True:
more_avaliable = True
while more_avaliable:
rsp = client.poll_events(poll_url)
for event in rsp['sets'].values():
logger.info(json.dumps(event, indent=2))
more_avaliable = rsp['moreAvailable']

time.sleep(5)


def create_app(config_filename: str = "config.cfg"):
# Define a flask app that handles the push requests
dictConfig({
Expand All @@ -57,8 +67,8 @@ def create_app(config_filename: str = "config.cfg"):
verify = app.config.get("VERIFY", True)

# Wait for transmitter to be available
transmitter_url = app.config["TRANSMITTER_URL"]
asyncio.run(wait_until_available(urllib.parse.urlparse(transmitter_url).netloc, 443))
transmitter_url = "https://" + app.config["TRANSMITTER_HOST"]
asyncio.run(wait_until_available(app.config["TRANSMITTER_HOST"], 443))

bearer = app.config.get('BEARER')
if not bearer:
Expand All @@ -71,14 +81,28 @@ def create_app(config_filename: str = "config.cfg"):
client = TransmitterClient(transmitter_url, app.config["AUDIENCE"], bearer, verify)
client.get_endpoints()
client.get_jwks()
client.configure_stream(f"{app.config['RECEIVER_URL']}/event")

client.configure_stream(app.config['STREAM_CONFIG'])

for subject in app.config["SUBJECTS"]:
client.add_subject(subject)

if client.stream_config['delivery']['method'].endswith('poll'):
poll_url = client.stream_config['delivery']['endpoint_url']
# Need to replace domain name in endpoint_url because there are different endpoint names
# for reciever and shared_signals_guide
poll_url_parsed = urllib.parse.urlparse(poll_url)
poll_url_parsed = poll_url_parsed._replace(netloc=app.config["TRANSMITTER_HOST"])
poll_url = poll_url_parsed.geturl()

thread = threading.Thread(target=poll_events_continuously,
args=(client, poll_url, app.logger))
thread.start()

@app.route('/event', methods=['POST'])
def receive_event():
body = request.get_data()
event = client.decode_body(body)
event = client.decode_event(body)
app.logger.info(json.dumps(event, indent=2))
return "", 202

Expand Down
37 changes: 25 additions & 12 deletions examples/receiver/receiver/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
# that can be found in the LICENSE file.

import uuid
from typing import Union, Any
from typing import Optional, Union, Any

import requests
import jwt
from jwcrypto.jwk import JWKSet
Expand All @@ -18,6 +19,7 @@ def __init__(self, transmitter_hostname: str, audience: str, bearer: str, verify
self.audience = audience
self.auth = {"Authorization": f"Bearer {bearer}"}
self.verify = verify
self.acks = []

def get_endpoints(self):
sse_config_response = requests.get(
Expand All @@ -30,7 +32,7 @@ def get_jwks(self):
jwks_response.raise_for_status()
self.jwks = JWKSet.from_json(jwks_response.text)

def decode_body(self, body: Union[str, bytes]):
def decode_event(self, body: Union[str, bytes]):
kid = jwt.get_unverified_header(body)["kid"]
jwk = self.jwks.get_key(kid)
key = jwt.PyJWK(jwk).key
Expand All @@ -42,20 +44,12 @@ def decode_body(self, body: Union[str, bytes]):
audience=self.audience,
)

def configure_stream(self, endpoint_url: str):
def configure_stream(self, config: dict[str, Any]):
""" Configure stream and return the current config """
config_response = requests.post(
url=self.sse_config["configuration_endpoint"],
verify=self.verify,
json={
'delivery': {
'method': 'https://schemas.openid.net/secevent/risc/delivery-method/push',
'endpoint_url': endpoint_url,
},
'events_requested': [
'https://schemas.openid.net/secevent/risc/event-type/credential-compromise',
]
},
json=config,
headers=self.auth,
)
config_response.raise_for_status()
Expand All @@ -77,3 +71,22 @@ def request_verification(self):
json={'state': uuid.uuid4().hex},
headers=self.auth,
)

def poll_events(self, poll_url: str, max_events: Optional[int] = None) -> dict[str, Any]:
""" Poll events """
rsp = requests.post(
url=poll_url,
verify=self.verify,
json={
'max_events': max_events,
'acks': self.acks,
},
headers=self.auth,
)

body = rsp.json()
body['sets'] = { id: self.decode_event(event) for id, event in body['sets'].items() }

self.acks = [id for id in body['sets']]

return body
19 changes: 17 additions & 2 deletions examples/receiver/receiver/config.cfg
Original file line number Diff line number Diff line change
@@ -1,10 +1,25 @@
TRANSMITTER_HOST = "transmitter"
RECEIVER_HOST = "receiver:5003"

AUDIENCE = "http://example_receiver"
TRANSMITTER_URL = "https://transmitter"
VERIFY = False
RECEIVER_URL = "http://receiver:5003"

SUBJECTS = [
{
"format": "email",
"email": "[email protected]"
}
]

STREAM_CONFIG = {
# 'delivery': {
# 'method': 'https://schemas.openid.net/secevent/risc/delivery-method/push',
# 'endpoint_url': 'http://' + RECEIVER_HOST + "/event"
# },
'delivery': {
'method': 'https://schemas.openid.net/secevent/risc/delivery-method/poll',
},
'events_requested': [
'https://schemas.openid.net/secevent/risc/event-type/credential-compromise',
]
}
4 changes: 4 additions & 0 deletions examples/transmitter/swagger_server/business_logic/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ def update_config(self, new_config: StreamConfiguration, save: bool=True) -> Str

config['events_delivered'] = list(supported.intersection(requested))
self.config = StreamConfiguration.parse_obj(config)

if isinstance(self.config.delivery, PollDeliveryMethod):
self.config.delivery.endpoint_url = POLL_ENDPOINT

if save:
self.save()
return self
Expand Down