Skip to content

Commit

Permalink
feat: add Head_middleware.t; accept it for SSE/websocket
Browse files Browse the repository at this point in the history
  • Loading branch information
c-cube committed Apr 15, 2024
1 parent 1955406 commit e136852
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 25 deletions.
72 changes: 47 additions & 25 deletions src/core/server.ml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,18 @@ module Middleware = struct
let[@inline] nil : t = fun h -> h
end

module Head_middleware = struct
type t = { handle: 'a. 'a Request.t -> 'a Request.t }

let[@inline] apply (self : t) req = self.handle req
let[@inline] apply' req (self : t) = self.handle req

let to_middleware (self : t) : Middleware.t =
fun h req ~resp ->
let req = self.handle req in
h req ~resp
end

(* a request handler. handles a single request. *)
type cb_path_handler = IO.Output.t -> Middleware.handler

Expand Down Expand Up @@ -44,7 +56,7 @@ end

type upgrade_handler = (module UPGRADE_HANDLER)

exception Upgrade of unit Request.t * upgrade_handler
exception Upgrade of Head_middleware.t list * unit Request.t * upgrade_handler

module type IO_BACKEND = sig
val init_addr : unit -> string
Expand All @@ -60,12 +72,12 @@ end
type handler_result =
| Handle of (int * Middleware.t) list * cb_path_handler
| Fail of resp_error
| Upgrade of upgrade_handler
| Upgrade of Head_middleware.t list * upgrade_handler

let unwrap_handler_result req = function
| Handle (l, h) -> l, h
| Fail (c, s) -> raise (Bad_req (c, s))
| Upgrade up -> raise (Upgrade (req, up))
| Upgrade (l, up) -> raise (Upgrade (l, req, up))

type t = {
backend: (module IO_BACKEND);
Expand Down Expand Up @@ -184,12 +196,13 @@ let[@inline] _opt_iter ~f o =
exception Exit_SSE
let add_route_server_sent_handler ?accept self route f =
let add_route_server_sent_handler ?accept ?(middlewares = []) self route f =
let tr_req (oc : IO.Output.t) req ~resp f =
let req =
Pool.with_resource self.bytes_pool @@ fun bytes ->
Request.read_body_full ~bytes req
in
let req = List.fold_left Head_middleware.apply' req middlewares in
let headers =
ref Headers.(empty |> set "content-type" "text/event-stream")
in
Expand Down Expand Up @@ -238,15 +251,16 @@ let add_route_server_sent_handler ?accept self route f =
in
add_route_handler_ self ?accept ~meth:`GET route ~tr_req f
let add_upgrade_handler ?(accept = fun _ -> Ok ()) (self : t) route f : unit =
let add_upgrade_handler ?(accept = fun _ -> Ok ()) ?(middlewares = [])
(self : t) route f : unit =
let ph req : handler_result option =
if req.Request.meth <> `GET then
None
else (
match accept req with
| Ok () ->
(match Route.Private_.eval req.Request.path_components route f with
| Some up -> Some (Upgrade up)
| Some up -> Some (Upgrade (middlewares, up))
| None -> None (* path didn't match *))
| Error err -> Some (Fail err)
)
Expand Down Expand Up @@ -347,9 +361,19 @@ let client_handle_for (self : t) ~client_addr ic oc : unit =
Response.Private_.output_ ~bytes:bytes_res oc resp
in
let handle_upgrade req (module UP : UPGRADE_HANDLER) : unit =
Log.debug (fun k -> k "upgrade connection");
let handle_upgrade ~(middlewares : Head_middleware.t list) req
(module UP : UPGRADE_HANDLER) : unit =
try
Log.debug (fun k -> k "upgrade connection");
let send_resp resp =
log_response req resp;
Response.Private_.output_ ~bytes:bytes_res oc resp
in
(* apply head middlewares *)
let req = List.fold_left Head_middleware.apply' req middlewares in
(* check headers *)
(match Request.get_header req "connection" with
| Some str when string_as_list_contains_ str "Upgrade" -> ()
Expand All @@ -364,18 +388,15 @@ let client_handle_for (self : t) ~client_addr ic oc : unit =
| Error msg ->
(* fail the upgrade *)
Log.error (fun k -> k "upgrade failed: %s" msg);
let resp = Response.make_raw ~code:429 "upgrade required" in
log_response req resp;
Response.Private_.output_ ~bytes:bytes_res oc resp
send_resp @@ Response.make_raw ~code:429 "upgrade required"
| Ok (headers, handshake_st) ->
(* send the upgrade reply *)
let headers =
[ "connection", "upgrade"; "upgrade", UP.name ] @ headers
in
let resp = Response.make_string ~code:101 ~headers (Ok "") in
log_response req resp;
Response.Private_.output_ ~bytes:bytes_res oc resp;
send_resp @@ Response.make_string ~code:101 ~headers (Ok "");
(* handshake successful, proceed with the upgrade handler *)
UP.handle_connection handshake_st ic oc
with e ->
let bt = Printexc.get_raw_backtrace () in
Expand All @@ -384,6 +405,15 @@ let client_handle_for (self : t) ~client_addr ic oc : unit =
let continue = ref true in
(* merge per-request middlewares with the server-global middlewares *)
let get_middlewares ~handler_middlewares () : _ list =
let global_middlewares = Lazy.force self.middlewares_sorted in
if handler_middlewares = [] then
global_middlewares
else
sort_middlewares_ (List.rev_append handler_middlewares self.middlewares)
in
let handle_one_req () =
match
let buf = Buf.of_bytes bytes_req in
Expand Down Expand Up @@ -422,15 +452,7 @@ let client_handle_for (self : t) ~client_addr ic oc : unit =
| Some s -> bad_reqf 417 "unknown expectation %s" s
| None -> ());
(* merge per-request middlewares with the server-global middlewares *)
let global_middlewares = Lazy.force self.middlewares_sorted in
let all_middlewares =
if handler_middlewares = [] then
global_middlewares
else
sort_middlewares_
(List.rev_append handler_middlewares self.middlewares)
in
let all_middlewares = get_middlewares ~handler_middlewares () in
(* apply middlewares *)
let handler oc =
Expand Down Expand Up @@ -484,10 +506,10 @@ let client_handle_for (self : t) ~client_addr ic oc : unit =
handle_one_req ()
done
with
| Upgrade (req, up) ->
| Upgrade (middlewares, req, up) ->
(* upgrades take over the whole connection, we won't process
any further request *)
handle_upgrade req up
handle_upgrade ~middlewares req up
| e ->
let bt = Printexc.get_raw_backtrace () in
handle_exn e bt
Expand Down
15 changes: 15 additions & 0 deletions src/core/server.mli
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,19 @@ module Middleware : sig
(** Trivial middleware that does nothing. *)
end

(** A middleware that only considers the request's head+headers.
These middlewares are simpler than full {!Middleware.t} and
work in more contexts.
@since NEXT_RELEASE *)
module Head_middleware : sig
type t = { handle: 'a. 'a Request.t -> 'a Request.t }
(** A handler that takes the request, without its body,
and possibly modifies it. *)

val to_middleware : t -> Middleware.t
end

(** {2 Main Server type} *)

type t
Expand Down Expand Up @@ -219,6 +232,7 @@ type server_sent_generator = (module SERVER_SENT_GENERATOR)

val add_route_server_sent_handler :
?accept:(unit Request.t -> (unit, Response_code.t * string) result) ->
?middlewares:Head_middleware.t list ->
t ->
('a, string Request.t -> server_sent_generator -> unit) Route.t ->
'a ->
Expand Down Expand Up @@ -270,6 +284,7 @@ type upgrade_handler = (module UPGRADE_HANDLER)

val add_upgrade_handler :
?accept:(unit Request.t -> (unit, Response_code.t * string) result) ->
?middlewares:Head_middleware.t list ->
t ->
('a, upgrade_handler) Route.t ->
'a ->
Expand Down

0 comments on commit e136852

Please sign in to comment.