From c75a377e2e4f2d03b5500aa10ccc6e367a7caf3f Mon Sep 17 00:00:00 2001 From: James Graham Date: Fri, 24 Jan 2025 10:50:09 +0000 Subject: [PATCH] Add serverCertificateHashes test server Add a second webtransport server, in order to test connection with a server, that has a self-signed certificate together with serverCertificateHashes. See https://github.com/web-platform-tests/rfcs/pull/216 --- tools/serve/serve.py | 33 +++++++++++-- .../webtransport/h3/webtransport_h3_server.py | 48 ++++++++++++++++--- tools/webtransport/requirements.txt | 1 + tools/wptrunner/wptrunner/environment.py | 48 ++++++++++++++++++- tools/wptserve/wptserve/config.py | 1 + tools/wptserve/wptserve/pipes.py | 2 + .../webtransport-test-helpers.sub.js | 9 +++- 7 files changed, 130 insertions(+), 12 deletions(-) diff --git a/tools/serve/serve.py b/tools/serve/serve.py index 6d385b34ac6da2..104c2e331ac9bc 100644 --- a/tools/serve/serve.py +++ b/tools/serve/serve.py @@ -15,6 +15,7 @@ import traceback import urllib import uuid +import datetime from collections import defaultdict, OrderedDict from io import IOBase from itertools import chain, product @@ -991,7 +992,7 @@ def start_servers(logger, host, ports, paths, routes, bind_address, config, continue # Skip WebTransport over HTTP/3 server unless if is enabled explicitly. - if scheme == 'webtransport-h3' and not kwargs.get("webtransport_h3"): + if scheme in ['webtransport-h3', 'webtransport-h3-cert-hash'] and not kwargs.get("webtransport_h3"): continue for port in ports: @@ -1009,6 +1010,7 @@ def start_servers(logger, host, ports, paths, routes, bind_address, config, "ws": start_ws_server, "wss": start_wss_server, "webtransport-h3": start_webtransport_h3_server, + "webtransport-h3-cert-hash": start_webtransport_h3_server_cert_hash, }[scheme] server_proc = ServerProc(mp_context, scheme=scheme) @@ -1174,18 +1176,40 @@ def start_webtransport_h3_server(logger, host, port, paths, routes, bind_address try: # TODO(bashi): Move the following import to the beginning of this file # once WebTransportH3Server is enabled by default. - from webtransport.h3.webtransport_h3_server import WebTransportH3Server # type: ignore + from webtransport.h3.webtransport_h3_server import WebTransportH3Server, WebTransportCertificateGeneration # type: ignore return WebTransportH3Server(host=host, port=port, doc_root=paths["doc_root"], + cert_mode=WebTransportCertificateGeneration.USEPREGENERATED, cert_path=config.ssl_config["cert_path"], key_path=config.ssl_config["key_path"], - logger=logger) + logger=logger, + cert_hash_info=None + ) except Exception as error: logger.critical( f"Failed to start WebTransport over HTTP/3 server: {error}") sys.exit(0) +def start_webtransport_h3_server_cert_hash(logger, host, port, paths, routes, bind_address, config, **kwargs): + try: + # TODO(bashi): Move the following import to the beginning of this file + # once WebTransportH3Server is enabled by default. + from webtransport.h3.webtransport_h3_server import WebTransportH3Server, WebTransportCertificateGeneration + return WebTransportH3Server(host=host, + port=port, + doc_root=paths["doc_root"], + cert_mode=WebTransportCertificateGeneration.GENERATEDVALIDSERVERCERTIFICATEHASHCERT, + cert_path=None, + key_path=None, + logger=logger, + cert_hash_info=config["cert_hash_info"] + ) + except Exception as error: + logger.critical( + f"Failed to start WebTransport over HTTP/3 server with certificate hashes: {error}") + sys.exit(0) + def start(logger, config, routes, mp_context, log_handlers, **kwargs): host = config["server_host"] @@ -1249,6 +1273,7 @@ class ConfigBuilder(config.ConfigBuilder): "ws": ["auto"], "wss": ["auto"], "webtransport-h3": ["auto"], + "webtransport-h3-cert-hash": ["auto"], }, "check_subdomains": True, "bind_address": True, @@ -1372,7 +1397,7 @@ def get_parser(): parser.add_argument("--no-h2", action="store_false", dest="h2", default=None, help="Disable the HTTP/2.0 server") parser.add_argument("--webtransport-h3", action="store_true", - help="Enable WebTransport over HTTP/3 server") + help="Enable WebTransport over HTTP/3 servers") parser.add_argument("--exit-after-start", action="store_true", help="Exit after starting servers") parser.add_argument("--verbose", action="store_true", help="Enable verbose logging") diff --git a/tools/webtransport/h3/webtransport_h3_server.py b/tools/webtransport/h3/webtransport_h3_server.py index 2dd8f645551d63..6ab8a81091cf17 100644 --- a/tools/webtransport/h3/webtransport_h3_server.py +++ b/tools/webtransport/h3/webtransport_h3_server.py @@ -8,10 +8,13 @@ import sys import threading import traceback -from enum import IntEnum +from enum import IntEnum, Enum from urllib.parse import urlparse from typing import Any, Dict, List, Optional, Tuple, cast +from cryptography import x509 +from cryptography.hazmat.primitives import serialization + # TODO(bashi): Remove import check suppressions once aioquic dependency is resolved. from aioquic.buffer import Buffer # type: ignore from aioquic.asyncio import QuicConnectionProtocol, serve # type: ignore @@ -31,6 +34,7 @@ from tools import localpaths # noqa: F401 from wptserve import stash from .capsule import H3Capsule, H3CapsuleDecoder, CapsuleType +from http.server import BaseHTTPRequestHandler, HTTPServer """ A WebTransport over HTTP/3 server for testing. @@ -499,6 +503,16 @@ def add(self, ticket: SessionTicket) -> None: def pop(self, label: bytes) -> Optional[SessionTicket]: return self.tickets.pop(label, None) +class WebTransportCertificateGeneration(Enum): + """ + Specify, if the server should generate a certificate or use an existing certificate + USEPREGENERATED: use existing certificate + GENERATEDVALIDSERVERCERTIFICATEHASHCERT: generate a certificate compatible to server cert hashes + """ + USEPREGENERATED = 1, + GENERATEDVALIDSERVERCERTIFICATEHASHCERT = 2 +# TODO add cases for invalid certificates + class WebTransportH3Server: """ @@ -507,18 +521,31 @@ class WebTransportH3Server: :param host: Host from which to serve. :param port: Port from which to serve. :param doc_root: Document root for serving handlers. + :paran cert_mode: The used certificate mode can be + USEPREGENERATED or GENERATEDVALIDSERVERCERTIFICATEHASHCERT :param cert_path: Path to certificate file to use. :param key_path: Path to key file to use. :param logger: a Logger object for this server. """ - def __init__(self, host: str, port: int, doc_root: str, cert_path: str, - key_path: str, logger: Optional[logging.Logger]) -> None: + def __init__(self, host: str, port: int, doc_root: str, cert_mode: WebTransportCertificateGeneration, + cert_path: Optional[str], key_path: Optional[str], logger: Optional[logging.Logger], + cert_hash_info: Optional[Dict]) -> None: self.host = host self.port = port self.doc_root = doc_root - self.cert_path = cert_path - self.key_path = key_path + if cert_path is not None: + self.cert_path = cert_path + if key_path is not None: + self.key_path = key_path + if cert_hash_info is not None: + self.cert_hash_info = cert_hash_info + self.cert_mode = cert_mode + if (cert_path is None or key_path is None) and cert_mode == WebTransportCertificateGeneration.USEPREGENERATED: + raise ValueError("Both cert_path and key_path must be provided, if cert_mode is USEPREGENERATED") + if (cert_hash_info is None and cert_mode == WebTransportCertificateGeneration.GENERATEDVALIDSERVERCERTIFICATEHASHCERT): + raise ValueError("cert_hash_info must be provided, if cert_mode is GENERATEDVALIDSERVERCERTIFICATEHASHCERT") + self.started = False global _doc_root _doc_root = self.doc_root @@ -551,7 +578,16 @@ def _start_on_server_thread(self) -> None: _logger.info("Starting WebTransport over HTTP/3 server on %s:%s", self.host, self.port) - configuration.load_cert_chain(self.cert_path, self.key_path) + if self.cert_mode == WebTransportCertificateGeneration.USEPREGENERATED: + configuration.load_cert_chain(self.cert_path, self.key_path) + else: # GENERATEDVALIDSERVERCERTIFICATEHASHCERT case + configuration.private_key = serialization.load_pem_private_key(self.cert_hash_info["private_key"], + password=None + ) + configuration.certificate = x509.load_pem_x509_certificate(self.cert_hash_info["certificate"]) + configuration.certificate_chain = [] + + ticket_store = SessionTicketStore() diff --git a/tools/webtransport/requirements.txt b/tools/webtransport/requirements.txt index 9d6dfa350f9ff3..884dfad8a53dd4 100644 --- a/tools/webtransport/requirements.txt +++ b/tools/webtransport/requirements.txt @@ -1 +1,2 @@ aioquic==1.2.0 +cryptography diff --git a/tools/wptrunner/wptrunner/environment.py b/tools/wptrunner/wptrunner/environment.py index 868ed8f91e3caa..49a9ba90013c9b 100644 --- a/tools/wptrunner/wptrunner/environment.py +++ b/tools/wptrunner/wptrunner/environment.py @@ -9,8 +9,14 @@ import socket import sys import time +import datetime from typing import Optional +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.x509.oid import NameOID +from cryptography import x509 + import mozprocess from mozlog import get_default_logger, handlers from mozlog.structuredlog import StructuredLogger @@ -46,6 +52,37 @@ def do_delayed_imports(logger, test_paths): (", ".join(failed), serve_root)) sys.exit(1) +def generate_hash_certificate(host: str) -> str: + private_key = ec.generate_private_key(ec.SECP256R1()) + subject = issuer = x509.Name([ + x509.NameAttribute(NameOID.COUNTRY_NAME, "DE"), + x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "Berlin"), + x509.NameAttribute(NameOID.LOCALITY_NAME, "Berlin"), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Wpt tests"), + x509.NameAttribute(NameOID.COMMON_NAME, host), + ]) + now = datetime.datetime.now(datetime.timezone.utc) + certificate = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(private_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(now) + .not_valid_after(now + datetime.timedelta(days=13)) + .sign(private_key, hashes.SHA256()) + ) + fingerprint = certificate.fingerprint(hashes.SHA256()) + server_certificate_hash = ":".join(f"{byte:02x}" for byte in fingerprint) + return { "certificate": certificate.public_bytes( + encoding=serialization.Encoding.PEM + ), + "private_key": private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption()), + "hash": server_certificate_hash + } def serve_path(test_paths): return test_paths["/"].tests_path @@ -150,7 +187,8 @@ def __enter__(self): self.get_routes(), mp_context=mpcontext.get_context(), log_handlers=[server_log_handler], - webtransport_h3=self.enable_webtransport) + webtransport_h3=self.enable_webtransport, + webtransport_h3_cert_hash=self.enable_webtransport) if self.options.get("supports_debugger") and self.debug_info and self.debug_info.interactive: self._stack.enter_context(self.ignore_interrupts()) @@ -197,6 +235,7 @@ def build_config(self): "wss": [8889], "h2": [9000], "webtransport-h3": [11000], + "webtransport-h3-cert-hash": [11001], } config.ports = ports @@ -221,6 +260,8 @@ def build_config(self): config.doc_root = serve_path(self.test_paths) config.inject_script = self.inject_script + config.cert_hash_info = generate_hash_certificate(config.server_host) + if self.suppress_handler_traceback is not None: config.logging["suppress_handler_traceback"] = self.suppress_handler_traceback @@ -323,10 +364,15 @@ def test_servers(self): for port, server in self.servers.get("webtransport-h3", []): if not webtranport_h3_server_is_running(host, port, timeout=5): pending.append((host, port)) + for port, server in self.servers.get("webtransport-h3-cert-hash", []): + if not webtranport_h3_server_is_running(host, port, timeout=5): + pending.append((host, port)) for scheme, servers in self.servers.items(): if scheme == "webtransport-h3": continue + if scheme == "webtransport-h3-cert-hash": + continue for port, server in servers: s = socket.socket() s.settimeout(0.1) diff --git a/tools/wptserve/wptserve/config.py b/tools/wptserve/wptserve/config.py index 50e20f05f0b24f..77b3372da377e8 100644 --- a/tools/wptserve/wptserve/config.py +++ b/tools/wptserve/wptserve/config.py @@ -128,6 +128,7 @@ class ConfigBuilder: _default = { "browser_host": "localhost", + "certificate_hash": {}, "alternate_hosts": {}, "doc_root": os.path.dirname("__file__"), "server_host": None, diff --git a/tools/wptserve/wptserve/pipes.py b/tools/wptserve/wptserve/pipes.py index 84b17c12285307..35323b0f8ad6c3 100644 --- a/tools/wptserve/wptserve/pipes.py +++ b/tools/wptserve/wptserve/pipes.py @@ -478,6 +478,8 @@ def config_replacement(match): value = variables[field] elif hasattr(SubFunctions, field): value = getattr(SubFunctions, field) + elif field == "server_certificate_hash": + value = request.server.config["cert_hash_info"]["hash"] elif field == "headers": value = request.headers elif field == "GET": diff --git a/webtransport/resources/webtransport-test-helpers.sub.js b/webtransport/resources/webtransport-test-helpers.sub.js index 36788699e8bbbe..79723d2a99ed95 100644 --- a/webtransport/resources/webtransport-test-helpers.sub.js +++ b/webtransport/resources/webtransport-test-helpers.sub.js @@ -3,16 +3,23 @@ const HOST = get_host_info().ORIGINAL_HOST; const PORT = '{{ports[webtransport-h3][0]}}'; +const PORT_CERT_HASH = '{{ports[webtransport-h3-cert-hash][0]}}'; const BASE = `https://${HOST}:${PORT}`; +const BASE_CERT_HASH = `https://${HOST}:${PORT_CERT_HASH}`; // Wait for the given number of milliseconds (ms). function wait(ms) { return new Promise(res => step_timeout(res, ms)); } // Create URL for WebTransport session. -function webtransport_url(handler) { +function webtransport_url(handler, options) { + if (options?.cert_hashes) { + return `${BASE_CERT_HASH}/webtransport/handlers/${handler}`; + } return `${BASE}/webtransport/handlers/${handler}`; } +const cert_hash = new Uint8Array('{{server_certificate_hash}}'.split(':').map((el) => parseInt(el, 16))); +const cert_hash_str = '{{server_certificate_hash}}' // Converts WebTransport stream error code to HTTP/3 error code. // https://ietf-wg-webtrans.github.io/draft-ietf-webtrans-http3/draft-ietf-webtrans-http3.html#section-4.3 function webtransport_code_to_http_code(n) {