Skip to content

Commit

Permalink
fix websocket: properly remember the offset in current frame
Browse files Browse the repository at this point in the history
not doing so means we always unmask from offset 0, which means we might
use the wrong index in the mask if we do, say, `read()=3` followed by
another read: the second one would start from mask[0] instead of
mask[3], producing raw unfiltered garbage.
  • Loading branch information
c-cube committed Apr 5, 2024
1 parent 5301ed4 commit 3393c13
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 20 deletions.
50 changes: 34 additions & 16 deletions src/ws/tiny_httpd_ws.ml
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ end
module Reader = struct
type state =
| Begin (** At the beginning of a frame *)
| Reading_frame of { mutable remaining_bytes: int }
| Reading_frame of { mutable remaining_bytes: int; mutable num_read: int }
(** Currently reading the payload of a frame with [remaining_bytes]
left to read from the underlying [ic] *)
| Close
Expand Down Expand Up @@ -268,22 +268,26 @@ module Reader = struct
self.header.payload_len self.header.mask);*)
()

external apply_masking_ : key:bytes -> buf:bytes -> int -> int -> unit
external apply_masking_ :
key:bytes -> key_offset:int -> buf:bytes -> int -> int -> unit
= "tiny_httpd_ws_apply_masking"
[@@noalloc]
(** Apply masking to the parsed data *)

let[@inline] apply_masking ~mask_key (buf : bytes) off len : unit =
let[@inline] apply_masking ~mask_key ~mask_offset (buf : bytes) off len : unit
=
assert (
Bytes.length mask_key = 4 && off >= 0 && off + len <= Bytes.length buf);
apply_masking_ ~key:mask_key ~buf off len
Bytes.length mask_key = 4
&& mask_offset >= 0 && off >= 0
&& off + len <= Bytes.length buf);
apply_masking_ ~key:mask_key ~key_offset:mask_offset ~buf off len

let read_body_to_string (self : t) : string =
let len = self.header.payload_len in
let buf = Bytes.create len in
IO.Input.really_input self.ic buf 0 len;
if self.header.mask then
apply_masking ~mask_key:self.header.mask_key buf 0 len;
apply_masking ~mask_key:self.header.mask_key ~mask_offset:0 buf 0 len;
Bytes.unsafe_to_string buf

(** Skip bytes of the body *)
Expand All @@ -303,33 +307,45 @@ module Reader = struct
self.state <- Begin;
read_rec self buf i len
| Reading_frame r ->
Printf.printf "reading len=%d from frame remaining=%d (key=%S)\n%!" len
r.remaining_bytes
(Bytes.unsafe_to_string self.header.mask_key);
let len = min len r.remaining_bytes in
let n = IO.Input.input self.ic buf i len in
Printf.printf "got n=%d bytes\n%!" n;
Printf.printf "in buf: %S\n%!" (Bytes.sub_string buf i n);

(* update state *)
r.remaining_bytes <- r.remaining_bytes - n;
if r.remaining_bytes = 0 then self.state <- Begin;

(* apply masking *)
if self.header.mask then
apply_masking ~mask_key:self.header.mask_key buf i n
apply_masking ~mask_key:self.header.mask_key ~mask_offset:r.num_read buf
i n
else (
Log.error (fun k -> k "websocket: client's frames must be masked");
raise Close_connection
);

(* update state *)
r.remaining_bytes <- r.remaining_bytes - n;
r.num_read <- r.num_read + n;
if r.remaining_bytes = 0 then self.state <- Begin;

Printf.printf "in buf (unmasked): %S\n%!" (Bytes.sub_string buf i n);
n
| Begin ->
read_frame_header self;
Log.debug (fun k ->
k "websocket: read frame of type=%s payload_len=%d"
k "websocket: read frame of type=%s payload_len=%d key=%S"
(Frame_type.show self.header.ty)
self.header.payload_len);
self.header.payload_len
(Bytes.unsafe_to_string self.header.mask_key));

(match self.header.ty with
| 0 ->
(* continuation *)
if self.last_ty = 1 || self.last_ty = 2 then
self.state <-
Reading_frame { remaining_bytes = self.header.payload_len }
Reading_frame
{ remaining_bytes = self.header.payload_len; num_read = 0 }
else (
Log.error (fun k ->
k "continuation frame coming after frame of type %s"
Expand All @@ -340,12 +356,14 @@ module Reader = struct
| 1 ->
(* text *)
self.state <-
Reading_frame { remaining_bytes = self.header.payload_len };
Reading_frame
{ remaining_bytes = self.header.payload_len; num_read = 0 };
read_rec self buf i len
| 2 ->
(* binary *)
self.state <-
Reading_frame { remaining_bytes = self.header.payload_len };
Reading_frame
{ remaining_bytes = self.header.payload_len; num_read = 0 };
read_rec self buf i len
| 8 ->
(* close frame *)
Expand Down
3 changes: 2 additions & 1 deletion src/ws/tiny_httpd_ws.mli
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ val add_route_handler :
(**/**)

module Private_ : sig
val apply_masking : mask_key:bytes -> bytes -> int -> int -> unit
val apply_masking :
mask_key:bytes -> mask_offset:int -> bytes -> int -> int -> unit
end

(**/**)
7 changes: 4 additions & 3 deletions src/ws/tiny_httpd_ws_stubs.c
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,19 @@
#include <caml/memory.h>
#include <caml/mlvalues.h>

CAMLprim value tiny_httpd_ws_apply_masking(value _mask_key, value _buf,
CAMLprim value tiny_httpd_ws_apply_masking(value _mask_key, value _mask_offset, value _buf,
value _offset, value _len) {
CAMLparam4(_mask_key, _buf, _offset, _len);
CAMLparam5(_mask_key, _mask_offset, _buf, _offset, _len);

char const *mask_key = String_val(_mask_key);
char *buf = Bytes_val(_buf);
intnat mask_offset = Int_val(_mask_offset);
intnat offset = Int_val(_offset);
intnat len = Int_val(_len);

for (intnat i = 0; i < len; ++i) {
unsigned char c = buf[offset + i];
unsigned char c_m = mask_key[i & 0x3];
unsigned char c_m = mask_key[(i + mask_offset) & 0x3];
buf[offset + i] = (unsigned char)(c ^ c_m);
}
CAMLreturn(Val_unit);
Expand Down

0 comments on commit 3393c13

Please sign in to comment.