From e1368525d808d7fbffb25b15da069054f0416f80 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Mon, 15 Apr 2024 12:05:48 -0400 Subject: [PATCH] feat: add `Head_middleware.t`; accept it for SSE/websocket --- src/core/server.ml | 72 +++++++++++++++++++++++++++++---------------- src/core/server.mli | 15 ++++++++++ 2 files changed, 62 insertions(+), 25 deletions(-) diff --git a/src/core/server.ml b/src/core/server.ml index 221ebefb..82e1a6f8 100644 --- a/src/core/server.ml +++ b/src/core/server.ml @@ -9,6 +9,18 @@ module Middleware = struct let[@inline] nil : t = fun h -> h end +module Head_middleware = struct + type t = { handle: 'a. 'a Request.t -> 'a Request.t } + + let[@inline] apply (self : t) req = self.handle req + let[@inline] apply' req (self : t) = self.handle req + + let to_middleware (self : t) : Middleware.t = + fun h req ~resp -> + let req = self.handle req in + h req ~resp +end + (* a request handler. handles a single request. *) type cb_path_handler = IO.Output.t -> Middleware.handler @@ -44,7 +56,7 @@ end type upgrade_handler = (module UPGRADE_HANDLER) -exception Upgrade of unit Request.t * upgrade_handler +exception Upgrade of Head_middleware.t list * unit Request.t * upgrade_handler module type IO_BACKEND = sig val init_addr : unit -> string @@ -60,12 +72,12 @@ end type handler_result = | Handle of (int * Middleware.t) list * cb_path_handler | Fail of resp_error - | Upgrade of upgrade_handler + | Upgrade of Head_middleware.t list * upgrade_handler let unwrap_handler_result req = function | Handle (l, h) -> l, h | Fail (c, s) -> raise (Bad_req (c, s)) - | Upgrade up -> raise (Upgrade (req, up)) + | Upgrade (l, up) -> raise (Upgrade (l, req, up)) type t = { backend: (module IO_BACKEND); @@ -184,12 +196,13 @@ let[@inline] _opt_iter ~f o = exception Exit_SSE -let add_route_server_sent_handler ?accept self route f = +let add_route_server_sent_handler ?accept ?(middlewares = []) self route f = let tr_req (oc : IO.Output.t) req ~resp f = let req = Pool.with_resource self.bytes_pool @@ fun bytes -> Request.read_body_full ~bytes req in + let req = List.fold_left Head_middleware.apply' req middlewares in let headers = ref Headers.(empty |> set "content-type" "text/event-stream") in @@ -238,7 +251,8 @@ let add_route_server_sent_handler ?accept self route f = in add_route_handler_ self ?accept ~meth:`GET route ~tr_req f -let add_upgrade_handler ?(accept = fun _ -> Ok ()) (self : t) route f : unit = +let add_upgrade_handler ?(accept = fun _ -> Ok ()) ?(middlewares = []) + (self : t) route f : unit = let ph req : handler_result option = if req.Request.meth <> `GET then None @@ -246,7 +260,7 @@ let add_upgrade_handler ?(accept = fun _ -> Ok ()) (self : t) route f : unit = match accept req with | Ok () -> (match Route.Private_.eval req.Request.path_components route f with - | Some up -> Some (Upgrade up) + | Some up -> Some (Upgrade (middlewares, up)) | None -> None (* path didn't match *)) | Error err -> Some (Fail err) ) @@ -347,9 +361,19 @@ let client_handle_for (self : t) ~client_addr ic oc : unit = Response.Private_.output_ ~bytes:bytes_res oc resp in - let handle_upgrade req (module UP : UPGRADE_HANDLER) : unit = - Log.debug (fun k -> k "upgrade connection"); + let handle_upgrade ~(middlewares : Head_middleware.t list) req + (module UP : UPGRADE_HANDLER) : unit = try + Log.debug (fun k -> k "upgrade connection"); + + let send_resp resp = + log_response req resp; + Response.Private_.output_ ~bytes:bytes_res oc resp + in + + (* apply head middlewares *) + let req = List.fold_left Head_middleware.apply' req middlewares in + (* check headers *) (match Request.get_header req "connection" with | Some str when string_as_list_contains_ str "Upgrade" -> () @@ -364,18 +388,15 @@ let client_handle_for (self : t) ~client_addr ic oc : unit = | Error msg -> (* fail the upgrade *) Log.error (fun k -> k "upgrade failed: %s" msg); - let resp = Response.make_raw ~code:429 "upgrade required" in - log_response req resp; - Response.Private_.output_ ~bytes:bytes_res oc resp + send_resp @@ Response.make_raw ~code:429 "upgrade required" | Ok (headers, handshake_st) -> (* send the upgrade reply *) let headers = [ "connection", "upgrade"; "upgrade", UP.name ] @ headers in - let resp = Response.make_string ~code:101 ~headers (Ok "") in - log_response req resp; - Response.Private_.output_ ~bytes:bytes_res oc resp; + send_resp @@ Response.make_string ~code:101 ~headers (Ok ""); + (* handshake successful, proceed with the upgrade handler *) UP.handle_connection handshake_st ic oc with e -> let bt = Printexc.get_raw_backtrace () in @@ -384,6 +405,15 @@ let client_handle_for (self : t) ~client_addr ic oc : unit = let continue = ref true in + (* merge per-request middlewares with the server-global middlewares *) + let get_middlewares ~handler_middlewares () : _ list = + let global_middlewares = Lazy.force self.middlewares_sorted in + if handler_middlewares = [] then + global_middlewares + else + sort_middlewares_ (List.rev_append handler_middlewares self.middlewares) + in + let handle_one_req () = match let buf = Buf.of_bytes bytes_req in @@ -422,15 +452,7 @@ let client_handle_for (self : t) ~client_addr ic oc : unit = | Some s -> bad_reqf 417 "unknown expectation %s" s | None -> ()); - (* merge per-request middlewares with the server-global middlewares *) - let global_middlewares = Lazy.force self.middlewares_sorted in - let all_middlewares = - if handler_middlewares = [] then - global_middlewares - else - sort_middlewares_ - (List.rev_append handler_middlewares self.middlewares) - in + let all_middlewares = get_middlewares ~handler_middlewares () in (* apply middlewares *) let handler oc = @@ -484,10 +506,10 @@ let client_handle_for (self : t) ~client_addr ic oc : unit = handle_one_req () done with - | Upgrade (req, up) -> + | Upgrade (middlewares, req, up) -> (* upgrades take over the whole connection, we won't process any further request *) - handle_upgrade req up + handle_upgrade ~middlewares req up | e -> let bt = Printexc.get_raw_backtrace () in handle_exn e bt diff --git a/src/core/server.mli b/src/core/server.mli index b79284f3..313c76a6 100644 --- a/src/core/server.mli +++ b/src/core/server.mli @@ -34,6 +34,19 @@ module Middleware : sig (** Trivial middleware that does nothing. *) end +(** A middleware that only considers the request's head+headers. + + These middlewares are simpler than full {!Middleware.t} and + work in more contexts. + @since NEXT_RELEASE *) +module Head_middleware : sig + type t = { handle: 'a. 'a Request.t -> 'a Request.t } + (** A handler that takes the request, without its body, + and possibly modifies it. *) + + val to_middleware : t -> Middleware.t +end + (** {2 Main Server type} *) type t @@ -219,6 +232,7 @@ type server_sent_generator = (module SERVER_SENT_GENERATOR) val add_route_server_sent_handler : ?accept:(unit Request.t -> (unit, Response_code.t * string) result) -> + ?middlewares:Head_middleware.t list -> t -> ('a, string Request.t -> server_sent_generator -> unit) Route.t -> 'a -> @@ -270,6 +284,7 @@ type upgrade_handler = (module UPGRADE_HANDLER) val add_upgrade_handler : ?accept:(unit Request.t -> (unit, Response_code.t * string) result) -> + ?middlewares:Head_middleware.t list -> t -> ('a, upgrade_handler) Route.t -> 'a ->